
Add new operation: BroadcastTensorsOp, test=develop #33294

Merged 1 commit into PaddlePaddle:develop on Jun 23, 2021

Conversation

@jim19930609 (Contributor) commented Jun 2, 2021

PR types

New features

PR changes

OPs

Describe

  • Added BroadcastTensorsOp and BroadcastTensorsGradOp
  • The behavior of this Op is aligned with torch.broadcast_tensors
  • Supports both CPU and GPU places

API:

paddle.broadcast_tensors() or paddle.tensor.broadcast_tensors()

Code position:

python/paddle/tensor/manipulation.py

Example:

import paddle
import numpy as np

x1 = paddle.to_tensor(np.random.random([1, 2, 3, 4]).astype('float32'))
x2 = paddle.to_tensor(np.random.random([1, 2, 1, 4]).astype('float32'))
x3 = paddle.to_tensor(np.random.random([1, 1, 3, 1]).astype('float32'))
out1, out2, out3 = paddle.broadcast_tensors(input=[x1, x2, x3])
# out1, out2, out3: tensors broadcasted from x1, x2, x3 with shape [1,2,3,4]
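
For reference, the [1, 2, 3, 4] result above follows the standard NumPy/PyTorch broadcasting rule that the description refers to. Below is a small illustrative helper, not part of this PR; it assumes all input shapes already have the same rank, as in the example:

def broadcast_shape(*shapes):
    # Illustrative only (not part of this PR). Assumes all shapes have equal rank.
    result = []
    for dims in zip(*shapes):
        sizes = set(dims) - {1}          # sizes other than 1 must all agree
        assert len(sizes) <= 1, "incompatible dims: %s" % (dims,)
        result.append(sizes.pop() if sizes else 1)
    return result

print(broadcast_shape([1, 2, 3, 4], [1, 2, 1, 4], [1, 1, 3, 1]))  # [1, 2, 3, 4]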

CN Doc PR: PaddlePaddle/docs#3606


@paddle-bot-old (bot) commented Jun 2, 2021

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot-old (bot) commented

Sorry to inform you that the CI runs for ac0d714 passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

@JiabinYang (Contributor) left a comment

Some comments

TCChenlong previously approved these changes Jun 21, 2021

@TCChenlong (Contributor) left a comment

LGTM

Inline review comment on python/paddle/tensor/manipulation.py, at these lines:

# out1, out2, out3: tensors broadcasted from x1, x2, x3 with shape [1,2,3,4]
"""

num_inputs = len(input)

A Contributor commented:

Does this value really need to be obtained on the Python side?

@jim19930609 (Contributor, Author) replied

It is used in core.ops.broadcast_tensors(input, num_inputs).
In theory there is no need to pass an extra num_inputs; it could be obtained dynamically from the inputs on the C++ side.

However, this "OutNum" is produced by the auto code generation and is used to decide how many output VarBases to create. Without changing the code-generation logic, it is hard to change this part.

split_op, which has similar logic, also gets its "num" from Python and passes it into core.ops.split().
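
A rough sketch of the dygraph branch being discussed. This is illustrative, not the exact code in this PR: the in_dygraph_mode check, the LayerHelper fallback, and the 'X'/'Out' slot names are assumptions based on how similar Paddle ops are written, while core.ops.broadcast_tensors(input, num_inputs) is the call mentioned above.

from paddle.fluid import core
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper

def broadcast_tensors(input, name=None):
    # num_inputs is computed on the Python side because the auto-generated
    # dygraph binding uses "OutNum" to decide how many output VarBases to create.
    num_inputs = len(input)
    if in_dygraph_mode():
        # Call signature taken from the comment above.
        return core.ops.broadcast_tensors(input, num_inputs)

    # Static-graph fallback (assumed structure): append the op via LayerHelper.
    helper = LayerHelper('broadcast_tensors', **locals())
    outs = [
        helper.create_variable_for_type_inference(dtype=input[0].dtype)
        for _ in range(num_inputs)
    ]
    helper.append_op(
        type='broadcast_tensors', inputs={'X': input}, outputs={'Out': outs})
    return outs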

@jim19930609 (Contributor, Author) replied

Confirmed with Qiuliang:

  1. For now the auto code generation does need to produce an OutNum; if more cases like broadcast_tensors come up later, we will consider updating the code-generation logic.
  2. Passing one extra int num_inputs currently has little impact on performance.

lanxianghit previously approved these changes Jun 22, 2021

@lanxianghit (Contributor) left a comment

LGTM

wanghuancoder previously approved these changes Jun 22, 2021

@wanghuancoder (Contributor) left a comment

LGTM

TCChenlong previously approved these changes Jun 23, 2021
wanghuancoder previously approved these changes Jun 23, 2021
lanxianghit previously approved these changes Jun 23, 2021
@JiabinYang JiabinYang merged commit affddfa into PaddlePaddle:develop Jun 23, 2021