[Prim] add batch_norm composite rule #49894
Conversation
Your PR has been submitted successfully. Thank you for your contribution to this open source project!
some comments
value_table[new_out.name] = new_out
to_bind[orig_out.name] = new_out.name
to_bind_rev[new_out.name] = orig_out.name
if new_out is not None:
This may cause some problems; it's not safe.
Remove the original if it's not None.
done
broadcast_to(reshape(bias, stats_shape), x_hat.shape),
)

batch_mean_ = assign(batch_mean)
Please add a notice here.
done
Force-pushed from 7d7fa8b to b3675ee.
Force-pushed from 51fef9d to 16d1225.
multiply(momentum, run_mean),
multiply(
subtract(ones(run_mean.shape, run_mean.dtype), momentum),
reshape(batch_mean, run_mean.shape),
Is this reshape unnecessary?
Removed the unnecessary operator.
multiply(momentum, run_var),
multiply(
subtract(ones(run_var.shape, run_var.dtype), momentum),
reshape(batch_var, run_var.shape),
And this one as well?
done
reduce_axes,
keepdim=True,
)
x_hat = divide(
Could we just use operators directly here? The logic would be more readable.
done
Force-pushed from 2187998 to 3035b49.
some comments, and please add some test results
@@ -448,6 +448,30 @@ def _test_use_sync(value):
    __sync_stat_with_flag(value)


# ops in forward_blacklist will not be replaced by composite ops.
prim_config = {"forward_blacklist": []}
Does this need to be a dict?
Or merge this into the backward blacklist?
The forward and backward blacklists will be unified later.
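
For context, a minimal sketch of how such a blacklist could gate composite lowering; lower_op and composite_rules are illustrative names assumed for this example, not the PR's actual code:

# Hypothetical sketch: ops in the forward blacklist keep their original
# kernel instead of being lowered to composite/primitive ops.
prim_config = {"forward_blacklist": ["batch_norm"]}  # assumed example entry

def lower_op(op_name, composite_rules, *args, **kwargs):
    # composite_rules maps op names to composite implementations.
    rule = composite_rules.get(op_name)
    if rule is None or op_name in prim_config["forward_blacklist"]:
        return None  # caller falls back to the original op
    return rule(*args, **kwargs)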
1 if i in reduce_axes else s for i, s in enumerate(x.shape)
)

batch_mean = zeros(run_mean.shape, run_mean.dtype)
Why initialize batch_mean as zeros?
This is the initial value of batch_mean. In training, the zeros are replaced by the actual batch values; in testing, batch_mean is not used.
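
For illustration, a self-contained NumPy sketch of that initialization pattern; the NumPy calls stand in for the primitive ops, and is_training plus the shapes are assumed for the example:

import numpy as np

run_mean = np.zeros(3, dtype=np.float32)           # [c]
x = np.random.randn(4, 3, 8, 8).astype(np.float32)  # [n, c, h, w]
is_training = True

batch_mean = np.zeros(run_mean.shape, run_mean.dtype)  # placeholder init
if is_training:
    # Overwritten with the real batch statistics during training.
    batch_mean = x.mean(axis=(0, 2, 3))
# In test mode, batch_mean keeps the zeros and is never read.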
fetch_list=[y],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
Use _set_prim_all_enabled(False) here instead.
done
data_format,
use_global_stats,
)
gradients = paddle.grad(res, x)
Should a random v be fed to paddle.grad, or not?
In the forward fn, the output of batch_norm is multiplied by a random vector, which is equivalent to feeding a random v.
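
A minimal, self-contained check of that equivalence (relu stands in for the batch_norm forward; all names here are illustrative): for any vector v, the gradient of sum(v * y) with respect to x equals the vector-Jacobian product of y with cotangent v.

import paddle

x = paddle.randn([4, 3])
x.stop_gradient = False
v = paddle.randn([4, 3])

y = paddle.nn.functional.relu(x)  # stand-in for the batch_norm forward
# Multiplying the output by v and summing...
g1 = paddle.grad(paddle.sum(v * y), x, retain_graph=True)[0]
# ...matches feeding v directly as grad_outputs.
g2 = paddle.grad(y, x, grad_outputs=v)[0]
assert paddle.allclose(g1, g2)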
Force-pushed from a03cba0 to 6e2d5b0.
Commit: move composite test case, remove unused var, add composite op blacklist
Force-pushed from 6e2d5b0 to ef1f8a1.
LGTM
LGTM for op_compat.yaml
PR types: New features
PR changes: Others
Describe
Add the batch_norm composite rule and support replacing batch_norm with primitive ops.
Composite rule
mean = reduce_sum(x) / nhw, shape = [c]
variance = reduce_sum(x * x) / nhw - mean * mean, shape = [c]
std_variance_inv = rsqrt(variance + epsilon), shape = [c]
y = scale * (x - mean) * std_variance_inv + bias, shape = [n, c, h, w]
moving_mean = moving_mean * momentum + (1.0 - momentum) * mean, shape = [c]
moving_variance = moving_variance * momentum + (1.0 - momentum) * variance, shape = [c]
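
A minimal NumPy sketch of the rule above for NCHW input in training mode; the function and variable names are illustrative, not Paddle's API:

import numpy as np

def batch_norm_composite(x, scale, bias, run_mean, run_var,
                         momentum=0.9, epsilon=1e-5):
    # x: [n, c, h, w]; scale, bias, run_mean, run_var: [c]
    n, c, h, w = x.shape
    nhw = n * h * w
    axes = (0, 2, 3)                                        # reduce over n, h, w
    mean = x.sum(axis=axes) / nhw                           # [c]
    variance = (x * x).sum(axis=axes) / nhw - mean * mean   # [c]
    std_variance_inv = 1.0 / np.sqrt(variance + epsilon)    # rsqrt, [c]
    s = (1, c, 1, 1)                                        # broadcast shape
    y = (scale.reshape(s) * (x - mean.reshape(s)) * std_variance_inv.reshape(s)
         + bias.reshape(s))                                 # [n, c, h, w]
    moving_mean = run_mean * momentum + (1.0 - momentum) * mean
    moving_variance = run_var * momentum + (1.0 - momentum) * variance
    return y, moving_mean, moving_variance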