
[Prim] add batch_norm composite rule #49894

Merged
merged 1 commit into PaddlePaddle:develop from composite_debug on Feb 7, 2023

Conversation

@cyber-pioneer (Contributor) commented Jan 17, 2023

PR types

New features

PR changes

Others

Describe

Add the batch_norm composite rule, so that batch_norm can be replaced by prim ops.

Composite rule

mean = reduce_sum(x) / nhw
variance = reduce_sum(x * x) / nhw - mean * mean
std_variance_inv = rsqrt(variance + epsilon), shape = [c]

y = scale * (x - mean) * std_variance_inv + bias, shape = [n, c, h, w]

moving_mean = moving_mean * momentum + (1.0 - momentum) * mean, shape = [c]
moving_variance = moving_variance * momentum + (1.0 - momentum) * variance, shape = [c]
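
For reference, a minimal NumPy sketch of this rule for NCHW input (illustrative only; the PR itself builds the graph from Paddle prim ops such as reduce_sum, rsqrt, and broadcast_to):

import numpy as np

def batch_norm_composite(x, scale, bias, run_mean, run_var,
                         momentum=0.9, epsilon=1e-5):
    # x: [n, c, h, w]; scale/bias/run_mean/run_var: [c]
    n, c, h, w = x.shape
    nhw = n * h * w
    axes = (0, 2, 3)                                  # reduce n, h, w; keep c
    mean = x.sum(axis=axes) / nhw                     # [c]
    var = (x * x).sum(axis=axes) / nhw - mean * mean  # [c]
    inv_std = 1.0 / np.sqrt(var + epsilon)            # rsqrt(var + eps), [c]
    s = (1, c, 1, 1)                                  # broadcast shape for per-channel stats
    y = (scale.reshape(s) * (x - mean.reshape(s)) * inv_std.reshape(s)
         + bias.reshape(s))                           # [n, c, h, w]
    new_run_mean = run_mean * momentum + (1.0 - momentum) * mean
    new_run_var = run_var * momentum + (1.0 - momentum) * var
    return y, new_run_mean, new_run_var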

@paddle-bot commented Jan 17, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@JiabinYang (Contributor) left a comment: some comments

value_table[new_out.name] = new_out
to_bind[orig_out.name] = new_out.name
to_bind_rev[new_out.name] = orig_out.name
if new_out is not None:
Contributor: This may cause problems; it's not safe.

Contributor: Remove the original if it's not None.
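
A hedged sketch of what the suggested guard might look like (value_table, to_bind, and to_bind_rev come from the excerpt above; the PR's actual fix may differ in detail):

# record the lowered output's bookkeeping only when it actually exists
if new_out is not None:
    value_table[new_out.name] = new_out
    to_bind[orig_out.name] = new_out.name
    to_bind_rev[new_out.name] = orig_out.name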

@cyber-pioneer (Author): done

broadcast_to(reshape(bias, stats_shape), x_hat.shape),
)

batch_mean_ = assign(batch_mean)
Contributor: Add Notice

@cyber-pioneer (Author): done

@cyber-pioneer force-pushed the composite_debug branch 2 times, most recently from 7d7fa8b to b3675ee (January 19, 2023 02:09)
@cyber-pioneer force-pushed the composite_debug branch 3 times, most recently from 51fef9d to 16d1225 (February 3, 2023 05:26)
multiply(momentum, run_mean),
multiply(
subtract(ones(run_mean.shape, run_mean.dtype), momentum),
reshape(batch_mean, run_mean.shape),
Collaborator: Is this reshape unnecessary?

@cyber-pioneer (Author): Removed the unnecessary op.

multiply(momentum, run_var),
multiply(
subtract(ones(run_var.shape, run_var.dtype), momentum),
reshape(batch_var, run_var.shape),
Collaborator: And this one?

@cyber-pioneer (Author): done

reduce_axes,
keepdim=True,
)
x_hat = divide(
Collaborator: Could you just use operators directly here? The logic would be more readable.

@cyber-pioneer (Author): done
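
For illustration, the normalization step written with Python operators instead of nested prim calls, as the review suggests (a hedged sketch; the tensors and shapes are made up for the example):

import paddle

x = paddle.randn([2, 3, 4, 4])
epsilon = 1e-5
batch_mean = x.mean(axis=(0, 2, 3), keepdim=True)                # [1, c, 1, 1]
batch_var = x.var(axis=(0, 2, 3), unbiased=False, keepdim=True)  # [1, c, 1, 1]

# operator form of the divide(subtract(...), ...) chain in the snippet above
x_hat = (x - batch_mean) / paddle.sqrt(batch_var + epsilon)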

@cyber-pioneer force-pushed the composite_debug branch 5 times, most recently from 2187998 to 3035b49 (February 5, 2023 12:05)
@JiabinYang (Contributor) left a comment: Some comments; also add some test results.

@@ -448,6 +448,30 @@ def _test_use_sync(value):
__sync_stat_with_flag(value)


# ops in forward_blacklist will not be replaced by composite ops.
prim_config = {"forward_blacklist": []}
Contributor: Does this need to be a dict?

Contributor: Or merge this into the backward blacklist.

@cyber-pioneer (Author): The forward and backward blacklists will be unified later.
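
A hedged usage sketch of the new config (prim_config and its forward_blacklist key come from the diff above; the op name is just an example):

# redefined locally for the sketch
prim_config = {"forward_blacklist": []}

# ops added here keep their original kernels and are not lowered to prim ops
prim_config["forward_blacklist"].append("batch_norm")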

1 if i in reduce_axes else s for i, s in enumerate(x.shape)
)

batch_mean = zeros(run_mean.shape, run_mean.dtype)
Contributor: Why initialize batch_mean to zeros?

@cyber-pioneer (Author): This is the initial value of batch_mean. In training, the zeros are replaced by the actual batch value; in inference, batch_mean is unused.
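
A small NumPy sketch of the pattern described above (shapes and the helper are illustrative):

import numpy as np

def batch_stats(x, run_mean, training):
    batch_mean = np.zeros_like(run_mean)     # init value only
    if training:
        batch_mean = x.mean(axis=(0, 2, 3))  # actual batch statistic, shape [c]
    # in eval mode batch_mean stays at the placeholder and is never used;
    # normalization uses run_mean instead
    return batch_mean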

fetch_list=[y],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
Contributor: Use _set_prim_all_enabled(False) instead.
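
For context, a hedged sketch of the suggestion (core here is paddle.fluid.core, as in the test snippet above; assuming it exposes the combined toggle):

from paddle.fluid import core

# disable prim lowering for both forward and backward in one call,
# instead of toggling only the forward flag
core._set_prim_all_enabled(False)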

@cyber-pioneer (Author): done

data_format,
use_global_stats,
)
gradients = paddle.grad(res, x)
Contributor: Feed random v or not?

@cyber-pioneer (Author): In the forward fn, the output of bn is multiplied by a random vector, which is equivalent to feeding a random v.
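
A hedged sketch of the equivalence being described (the layer and shapes here are illustrative, not the test's actual setup):

import paddle

x = paddle.randn([2, 3, 4, 4])
x.stop_gradient = False
v = paddle.randn([2, 3, 4, 4])     # the random vector v
bn = paddle.nn.BatchNorm2D(3)

res = (bn(x) * v).sum()            # multiply bn output by v inside the forward fn
(grad,) = paddle.grad(res, [x])    # same gradient as feeding v via grad_outputs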

@cyber-pioneer force-pushed the composite_debug branch 4 times, most recently from a03cba0 to 6e2d5b0 (February 6, 2023 12:23)
move composite test case

remove unuseful var

add composite op blacklist
@XiaoguangHu01 (Contributor) left a comment: LGTM

@zyfncg (Contributor) left a comment: LGTM for op_compat.yaml

@JiabinYang merged commit 9b3a41b into PaddlePaddle:develop on Feb 7, 2023
@cyber-pioneer changed the title from "add batch_norm composite rule" to "[Prim] add batch_norm composite rule" on Feb 20, 2023