-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.3】为 Paddle 新增 masked_fill API #57355
Changes from 31 commits
4925ae4
fbec78a
3cfde41
37dd891
793fcdc
0c1e8ce
faa459a
0a4d8d4
746a818
d48d73d
7a0ef26
de8cd75
fb1a2ea
810575f
1b213bf
df6013e
82bb307
1349206
9ef08e4
5f463ad
fe74eff
4d394a0
211d655
0703cd0
1c0abb8
f9bdfa0
f5db8f6
0f48c13
c8040ce
ccb535c
2317773
f678e39
d6b7bd7
bccd1f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4561,6 +4561,74 @@ def moveaxis(x, source, destination, name=None): | |||||||||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def masked_fill(x, mask, value, name=None): | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||
x (Tensor) : The Destination Tensor. Supported data types are float, | ||||||||||||||||||||||||||||||||||||||
double, int, int64_t,float16 and bfloat16. | ||||||||||||||||||||||||||||||||||||||
mask (Tensor): The boolean tensor indicate the position to be filled. | ||||||||||||||||||||||||||||||||||||||
The data type of mask must be bool. | ||||||||||||||||||||||||||||||||||||||
value (Scalar or 0-D Tensor): The value used to fill the target tensor. | ||||||||||||||||||||||||||||||||||||||
Supported data types are float, double, int, int64_t,float16 and bfloat16. | ||||||||||||||||||||||||||||||||||||||
name(str, optional): The default value is None. Normally there is no | ||||||||||||||||||||||||||||||||||||||
need for user to set this property. For more information, please | ||||||||||||||||||||||||||||||||||||||
refer to :ref:`api_guide_Name`. | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||
Tensor, same dimention and dtype with x. | ||||||||||||||||||||||||||||||||||||||
Examples: | ||||||||||||||||||||||||||||||||||||||
.. code-block:: python | ||||||||||||||||||||||||||||||||||||||
>>> # doctest: +REQUIRES(env:GPU) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||
>>> import paddle | ||||||||||||||||||||||||||||||||||||||
>>> x = paddle.ones((3, 3), dtype="float32") | ||||||||||||||||||||||||||||||||||||||
>>> mask = paddle.to_tensor([[True, True, False]]) | ||||||||||||||||||||||||||||||||||||||
>>> print(mask) | ||||||||||||||||||||||||||||||||||||||
Tensor(shape=[1, 3], dtype=bool, place=Place(gpu:0), stop_gradient=True, | ||||||||||||||||||||||||||||||||||||||
[[True , True , False]]) | ||||||||||||||||||||||||||||||||||||||
>>> out = paddle.masked_fill(x, mask, 2) | ||||||||||||||||||||||||||||||||||||||
>>> print(out) | ||||||||||||||||||||||||||||||||||||||
Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, | ||||||||||||||||||||||||||||||||||||||
[[2., 2., 1.], | ||||||||||||||||||||||||||||||||||||||
[2., 2., 1.], | ||||||||||||||||||||||||||||||||||||||
[2., 2., 1.]]) | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
if np.isscalar(value): | ||||||||||||||||||||||||||||||||||||||
value = paddle.full([], value, x.dtype) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
mask = paddle.logical_not(mask) | ||||||||||||||||||||||||||||||||||||||
out = paddle.where(mask, x, value) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是否可以把value和x对调,省去一个not op操作 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该是不能对调的,paddle.where(cond, x, y) 在 inplace 的时候对boardcast的处理如下: zeros_like_x = paddle.zeros_like(x)
zeros_like_y = paddle.zeros_like(y)
zeros_like_condition = paddle.zeros_like(condition)
zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
cast_cond = paddle.cast(condition, x.dtype)
broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
broadcast_x = x.add_(broadcast_zeros)
broadcast_y = paddle.add(y, broadcast_zeros)
broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
broadcast_condition = paddle.cast(broadcast_condition, 'bool') 其中 |
||||||||||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
@inplace_apis_in_dygraph_only | ||||||||||||||||||||||||||||||||||||||
def masked_fill_(x, mask, value, name=None): | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
Inplace version of ``masked_fill`` API, the output Tensor will be inplaced with input ``x``. | ||||||||||||||||||||||||||||||||||||||
Please refer to :ref:`api_paddle_masked_fill`. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
Examples: | ||||||||||||||||||||||||||||||||||||||
.. code-block:: python | ||||||||||||||||||||||||||||||||||||||
>>> # doctest: +REQUIRES(env:GPU) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. all done |
||||||||||||||||||||||||||||||||||||||
>>> import paddle | ||||||||||||||||||||||||||||||||||||||
>>> x = paddle.ones((3, 3), dtype="float32") | ||||||||||||||||||||||||||||||||||||||
>>> mask = paddle.to_tensor([[True, False, False]]) | ||||||||||||||||||||||||||||||||||||||
>>> out = paddle.masked_fill_(x, mask, 2) | ||||||||||||||||||||||||||||||||||||||
>>> print(out) | ||||||||||||||||||||||||||||||||||||||
Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, | ||||||||||||||||||||||||||||||||||||||
[[2., 1., 1.], | ||||||||||||||||||||||||||||||||||||||
[2., 1., 1.], | ||||||||||||||||||||||||||||||||||||||
[2., 1., 1.]]) | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
if np.isscalar(value): | ||||||||||||||||||||||||||||||||||||||
value = paddle.full([], value, x.dtype) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
mask = paddle.logical_not(mask) | ||||||||||||||||||||||||||||||||||||||
out = paddle.where_(mask, x, value) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上,是否可以通过调换x , value 来避免额外的not操作 |
||||||||||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def non_negative_axis(arr, axis): | ||||||||||||||||||||||||||||||||||||||
ndim = len(arr.shape) | ||||||||||||||||||||||||||||||||||||||
if axis >= 0: | ||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数说明中写明支持的dtype,且要和设计文档中的一致;
Scaler
->Scalar
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done