【Prim】Optimize composite rules by making scalar shape [1] #51960

Merged · 9 commits · Mar 28, 2023
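Summary of the change: composite rules previously materialized scalar constants at the full input shape via full(x.shape, ...). Because these constants only ever combine with the input through broadcasting, a shape-[1] tensor suffices, except for 0-D inputs, which must keep their empty shape so the output rank is preserved. A minimal sketch of the pattern applied throughout this diff (illustrative only; x stands for a rule's input tensor, and full is the primitive helper already used in composite_rules.py):

    # before: scalar constant materialized at the full input shape
    half = full(x.shape, 0.5, x.dtype)

    # after: a shape-[1] constant broadcasts against x;
    # 0-D inputs (x.shape == []) keep their empty shape
    full_shape = x.shape if len(x.shape) == 0 else [1]
    half = full(full_shape, 0.5, x.dtype)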
@@ -158,6 +158,15 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
'x4', shape=weight.shape, dtype=str(weight.dtype)
)
x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype))
+if attrs.use_global_stats is None:
+    attrs.use_global_stats = not attrs.training
+    trainable_statistics = False
+else:
+    trainable_statistics = not attrs.use_global_stats
+
+use_run_stat = (
+    (not attrs.training) and (not trainable_statistics)
+) or attrs.use_global_stats
y = fn(
x1,
x2,
@@ -177,16 +186,27 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
blocks[0].ops[0].output_names, blocks[0].ops[0].output_arg_names
)
)
-vars_list = [
-    names[key]
-    for key in [
-        "Y",
-        "MeanOut",
-        "VarianceOut",
-        "SavedMean",
-        "SavedVariance",
-    ]
-]
+if not use_run_stat:
+    vars_list = [
+        names[key]
+        for key in [
+            "Y",
+            "MeanOut",
+            "VarianceOut",
+            "SavedMean",
+            "SavedVariance",
+        ]
+    ]
+else:
+    vars_list = [
+        names[key]
+        for key in [
+            "Y",
+            "MeanOut",
+            "VarianceOut",
+        ]
+    ]

fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that batch_norm in original block
@@ -202,21 +222,36 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
exe.run(startup_program)

# indeed SavedVariance is 1/sqrt(batch_var+eps)
-Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run(
-    main_program,
-    feed={
-        'x1': inputs,
-        'x2': running_mean,
-        'x3': running_variance,
-        'x4': weight,
-        'x5': bias,
-    },
-    fetch_list=vars_list,
-)
+if not use_run_stat:
+    Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run(
+        main_program,
+        feed={
+            'x1': inputs,
+            'x2': running_mean,
+            'x3': running_variance,
+            'x4': weight,
+            'x5': bias,
+        },
+        fetch_list=vars_list,
+    )
+else:
+    Y, MeanOut, VarianceOut = exe.run(
+        main_program,
+        feed={
+            'x1': inputs,
+            'x2': running_mean,
+            'x3': running_variance,
+            'x4': weight,
+            'x5': bias,
+        },
+        fetch_list=vars_list,
+    )
paddle.disable_static()
core._set_prim_all_enabled(False)

-return Y, MeanOut, VarianceOut, SavedMean, SavedVariance
+if not use_run_stat:
+    return Y, MeanOut, VarianceOut, SavedMean, SavedVariance
+else:
+    return Y, MeanOut, VarianceOut


class TestCompositeBatchNorm(unittest.TestCase):
43 changes: 26 additions & 17 deletions python/paddle/incubate/autograd/composite_rules.py
@@ -139,8 +139,10 @@ def composite_batchnorm(

# reserve_space is not needed in composite rule, but still return None to keep same as phi op definition.
reserve_space = None

-return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space
+if not use_run_stat:
+    return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space
+else:
+    return y, run_mean_, run_var_, None, None, reserve_space


@REGISTER_COMPOSITE('layer_norm')
@@ -188,12 +190,13 @@ def gelu_composite(x, approximate):
0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc
)
M_2_SQRTPI = 1.12837916709551257390 # /* 2/sqrt(pi) */
-one = ones(x.shape, x.dtype)
-half = full(x.shape, 0.5, x.dtype)
+full_shape = x.shape if len(x.shape) == 0 else [1]
+one = ones(full_shape, x.dtype)
Review comment (Contributor): the condition len(x.shape) == 0 won't exist; better to check other same places.

Reply (Contributor Author): 0D cases.
+half = full(full_shape, 0.5, x.dtype)
if approximate:
# gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
-kAlpha = full(x.shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
-GELU_CONSTANT = full(x.shape, 0.044715, x.dtype)
+kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
+GELU_CONSTANT = full(full_shape, 0.044715, x.dtype)
tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
out = x * half * (one + tanh_out)
return out
Expand All @@ -215,7 +218,7 @@ def mean_composite(x, axis, keepdim):
operator.mul, [x.shape[axis] for axis in axes]
)
norm = fill_constant(
-shape=sum_x.shape,
+shape=x.shape if len(x.shape) == 0 else [1],
value=value_to_fill,
dtype=sum_x.dtype,
)
@@ -321,7 +324,9 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis):
start_dim = start_axis if len(shape_in) != 0 else 0
end_dim = stop_axis if len(shape_in) != 0 else 0
assert start_dim <= end_dim
-if len(shape_in) == 0 or start_dim == end_dim:
+if len(shape_in) == 0:
+    return reshape(x, shape=[1]), None
+if start_dim == end_dim:
return reshape(x, shape=shape_in), None
slice_numel = 1
for i in range(start_dim, end_dim + 1):
@@ -377,7 +382,7 @@ def bernoulli(shape, dtype, p, seed=0):
return cast(
greater_equal(
uniform(shape, new_dtype, min=0.0, max=1.0, seed=seed),
-fill_constant(shape, new_dtype, p),
+fill_constant(shape if len(shape) == 0 else [1], new_dtype, p),
),
dtype,
)
@@ -394,16 +399,17 @@ def hard_swish_composite(x):
offset = 3.0
threshold = 6.0
scale = 6.0
+full_shape = x.shape if len(x.shape) == 0 else [1]
res = (
minimum(
maximum(
-x + full(x.shape, offset, dtype=x.dtype),
-full(x.shape, 0.0, dtype=x.dtype),
+x + full(full_shape, offset, dtype=x.dtype),
+full(full_shape, 0.0, dtype=x.dtype),
),
-full(x.shape, threshold, dtype=x.dtype),
+full(full_shape, threshold, dtype=x.dtype),
)
* x
-/ full(x.shape, scale, dtype=x.dtype)
+/ full(full_shape, scale, dtype=x.dtype)
)
return res

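As a side note, the rule above computes hard_swish(x) = min(max(x + offset, 0), threshold) * x / scale with offset=3, threshold=6, scale=6; a quick NumPy reference check (not part of the PR, constants taken from the code above):

    import numpy as np

    def hard_swish_ref(x, offset=3.0, threshold=6.0, scale=6.0):
        # mirrors the composite rule: min(max(x + offset, 0), threshold) * x / scale
        return np.minimum(np.maximum(x + offset, 0.0), threshold) * x / scale

    print(hard_swish_ref(np.array([-4.0, -1.0, 0.0, 1.0, 4.0])))
    # [ 0.         -0.33333333  0.          0.66666667  4.        ]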
@@ -504,7 +510,7 @@ def sqrt_composite(x):
define composite rule of op sqrt
res = pow(x, 0.5)
"""
-y = full(x.shape, 0.5, x.dtype)
+y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype)
res = pow(x, y)
return res

@@ -516,7 +522,7 @@ def pow_composite(x, y):
res = x^y
"""
if isinstance(y, (int, float)):
-y = full([1], y, x.dtype)
+y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype)
res = pow(x, y)
return res

@@ -525,7 +531,10 @@ def pow_composite(x, y):
def relu_composite(x):
"""define composite rule of op relu."""
# relu(x) = max(x, 0)
-return maximum(x, zeros_like(x))
+if len(x.shape) == 0:
+    return maximum(x, full(x.shape, 0.0, x.dtype))
+else:
+    return maximum(x, full([1], 0.0, x.dtype))


@REGISTER_COMPOSITE('unsqueeze2')
@@ -552,5 +561,5 @@ def unsqueeze_composite(x, axis):
def rsqrt_composite(x):
"""define composite rule of op rsqrt."""
# rsqrt(x) = x^(-0.5)
-y = full(x.shape, -0.5, x.dtype)
+y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
return pow(x, y)
15 changes: 7 additions & 8 deletions python/paddle/incubate/autograd/primx.py
@@ -633,14 +633,13 @@ def expand_nested_list(xs):
f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
f'but orig_out: {orig_out.name}.dtype={orig_out.dtype} and new_out: {new_out.name}.dtype={new_out.dtype}'
)
-if orig_out.shape and new_out.shape:
-    assert (
-        -1 not in new_out.shape
-    ), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
-    assert orig_out.shape == new_out.shape, (
-        f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
-        f'but orig_out: {orig_out.name}.shape={orig_out.shape} and new_out: {new_out.name}.shape={new_out.shape}'
-    )
+assert (
+    -1 not in new_out.shape
+), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
+assert orig_out.shape == new_out.shape, (
+    f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
+    f'but orig_out: {orig_out.name}.shape={orig_out.shape} and new_out: {new_out.name}.shape={new_out.shape}'
+)
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."