[AutoParallel] convert distensor for eager custom op #59137
Conversation
Your PR was submitted successfully. Thank you for contributing to the open-source project!
@@ -44,7 +44,7 @@ void ShareTensor(PyObject* src, PyObject* dst) {
   }
 }

-paddle::Tensor CastPyArg2Tensor(PyObject* obj, Py_ssize_t arg_pos) {
+paddle::Tensor& CastPyArg2Tensor(PyObject* obj, Py_ssize_t arg_pos) {
Why was this changed to return a reference?
Because the Tensor held inside the PyObject* needs to be modified in place; if a copy were returned, the in-place modification would not take effect.
-      paddle::Tensor tensor =
-          std::move(CastPyArg2Tensor(obj, i + 1));  // NOLINT
-      ctx.EmplaceBackInput(std::move(tensor));
+      paddle::Tensor& tensor = CastPyArg2Tensor(obj, i + 1);  // NOLINT
Suggested change:
-      paddle::Tensor& tensor = CastPyArg2Tensor(obj, i + 1);  // NOLINT
+      const paddle::Tensor& tensor = CastPyArg2Tensor(obj, i + 1);  // NOLINT
done, thx!
  const phi::distributed::ProcessMesh* mesh = nullptr;
  if (InputsContainDistTensor(&mesh, *(ctx.AllMutableInput()))) {
    ctx.AllMutableInput()->clear();
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& input = inputs.at(i);
      // op_type has already been parsed, so use i + 1 here
      PyObject* obj = PyTuple_GET_ITEM(args, i + 1);
      // Py_None passed from Python means an optional input was omitted;
      // use one un-initialized tensor to represent both Tensor and
      // vector<Tensor> inputs.
      if (obj == Py_None) {
        VLOG(7) << "Custom operator add input " << input
                << " to CustomOpKernelContext. Add un-initialized tensor "
                   "because the optional input is None";
        ctx.EmplaceBackInput(std::move(paddle::Tensor()));
        continue;
      }
      if (paddle::framework::detail::IsDuplicableVar(input)) {
        std::vector<paddle::Tensor> tensors =
            std::move(CastPyArg2VectorOfTensor(obj, i + 1, mesh));  // NOLINT
        ctx.EmplaceBackInputs(std::move(tensors));
        VLOG(7) << "Custom operator add input " << input
                << " to CustomOpKernelContext. Add vector<Tensor> size = "
                << ctx.InputRangeAt(i).second - ctx.InputRangeAt(i).first;
      } else {
        paddle::Tensor& tensor = CastPyArg2Tensor(obj, i + 1);  // NOLINT
        ConvertAllInputsToDistTensor(mesh, tensor);
This code largely duplicates the earlier handling of the inputs. Could the two be merged, or factored out into a function?
Also, `Tensor&` is used in many places here; consider changing these to `const Tensor&`.
I thought about it; this cannot be merged or extracted into a function.
It cannot be merged because a DistTensor may appear anywhere in the passed-in arguments, even in the middle, so I have to rescan all of them: if a Tensor is not a DistTensor, convert it to one; if it is a DistTensor, its mesh must equal the others'. A full second pass is required.
It cannot be written as a function because, although the two code blocks have the same structure, the concrete operations differ. Forcing them into one function would only make the code harder to read.
LGTM
PR types
Others
PR changes
Others
Description
For a custom operator, if any one input Tensor is a DistTensor, convert all inputs to DistTensor.
Pcard-73145