Add multiplex operator #4064
Conversation
auto num_ins = ins.size();
PADDLE_ENFORCE(num_ins > 2,
               "multiplex operator should have more than 2 inputs.");
PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
We also have to check the index values in ins[0]; each index in ins[0] must be less than ins[0]->dims().
Done. Added the index check in the forward compute function.
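One plausible form of such a bounds check in the forward compute function (a hypothetical sketch; the variable names follow the diff excerpts in this review, and the exact condition and message are assumptions):
auto index = ins[0]->data<T>();
for (auto i = 0; i < rows; i++) {
  // ins[0] holds the selector; each value must pick one of the num_ins - 1
  // candidate inputs ins[1] .. ins[num_ins - 1].
  PADDLE_ENFORCE(index[i] >= 0 && static_cast<size_t>(index[i]) + 1 < num_ins,
                 "index exceeds the number of candidate input tensors.");
}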
paddle/operators/multiplex_op.cc
Outdated
"Input(Out@GRAD) shouldn't be null."); | ||
auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X")); | ||
auto ins = ctx.MultiInput<Tensor>("X"); | ||
// don;t compute gradient for index |
don;t --> don't
Done
paddle/operators/multiplex_op.cu
Outdated
auto index = index_t_cpu.data<T>();
for (auto i = 0; i < rows; i++) {
  int k = (int)index[i] + 1;
  cudaMemcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
Please use a CUDA stream, e.g.:
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                  ctx.device_context())
                  .stream();
platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
memory::Copy(place, out->data<T>() + i * cols, place,
             ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
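Combined with the copy loop from the diff hunk above, the stream-based version might look roughly like this (a sketch assuming the same rows, cols, index, ins, and out variables as in the excerpt):
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                  ctx.device_context())
                  .stream();
auto place = boost::get<platform::GPUPlace>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
  // Candidate k is selected by the index stored in ins[0].
  int k = static_cast<int>(index[i]) + 1;
  memory::Copy(place, out->data<T>() + i * cols, place,
               ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
}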
Done
paddle/operators/multiplex_op.h
Outdated
auto cols = ins[1]->dims()[1];
for (auto i = 0; i < rows; i++) {
  int k = (int)index[i] + 1;
  memcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
Maybe we can combine the CPU code and the CUDA code in one file:
template <typename Place, typename T>
class MultiplexKernel : public framework::OpKernel
We can use
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
to set a CPU/GPU tensor to zero, and we can use
memory::Copy
for both the CPU and GPU copies.
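A hedged sketch of what such a device-templated gradient kernel could look like (hypothetical; the name MultiplexGradKernel and the details below follow this suggestion rather than the PR's final code, which later splits CPU and GPU again):
template <typename Place, typename T>
class MultiplexGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto d_ins =
        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
    // ins[0] is the index tensor, so its gradient is skipped.
    for (size_t i = 1; i < d_ins.size(); i++) {
      if (d_ins[i] == nullptr) continue;
      d_ins[i]->mutable_data<T>(ctx.GetPlace());
      auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
      // Zeroing works on both devices via the Eigen device abstraction.
      t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
    }
    // The row-wise copies from Out@GRAD into the selected d_ins[k] rows would
    // follow, e.g. via memory::Copy; its CPU and GPU overloads differ (the GPU
    // one takes a cudaStream_t), which is one reason the CPU and GPU kernels
    // were eventually kept separate.
  }
};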
Done
It seems that merging the CPU/GPU code together is not a good idea here; I made a mistake.
If CPU and GPU both use Eigen, we can reuse the code easily. But if not, it's actually better to split the CPU and GPU implementations.
Done. Split the CPU/GPU code again.
paddle/operators/multiplex_op.cc
Outdated
class MultiplexOp : public framework::OperatorWithKernel {
 public:
  MultiplexOp(const std::string &type, const framework::VariableNameMap &inputs,
Why not use using framework::OperatorWithKernel::OperatorWithKernel; instead?
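For reference, a minimal sketch of the suggested pattern (C++11 inheriting constructors), with the rest of the class omitted:
class MultiplexOp : public framework::OperatorWithKernel {
 public:
  // Inherit every constructor of OperatorWithKernel instead of redeclaring it.
  using framework::OperatorWithKernel::OperatorWithKernel;
  // InferShape etc. unchanged.
};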
Modified
Thanks for the valuable comments. Please review the changes.
LGTM
Resolve #4010