-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pten] add concat pten kernel #38955
[pten] add concat pten kernel #38955
Conversation
Thanks for your contribution! |
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and | |||
limitations under the License. */ | |||
|
|||
#include "paddle/fluid/operators/math/concat_and_split.h" | |||
|
|||
#include "paddle/pten/kernels/cpu/concat_and_split.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议直接把concat_and_split.cc迁移过来,马上我们也要把它移过来了,可以下个PR再做
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
#include "paddle/pten/kernels/funcs/concat_funcs.h" | ||
namespace pten { | ||
|
||
DenseTensorMeta ConcatInferMeta(const std::vector<DenseTensorMeta>& x_meta, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里后面会改成的返回值作为输入参数指针的形式,和kernel保持一致
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
|
||
DenseTensorMeta ConcatInferMeta(const std::vector<DenseTensorMeta>& x_meta, | ||
const Scalar& axis_scalar, | ||
bool is_runtime) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is_runtime后面会用一个结构体封装起来
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
#include "paddle/pten/core/tensor_meta.h" | ||
namespace pten { | ||
|
||
// TODO(chentianyu03) use std::vector<DenseTensor> as InferMeta inputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里后面会新增MetaTensor概念,作为inferMeta的输入
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,后续再调整
#include "paddle/fluid/framework/mixed_vector.h" | ||
#include "paddle/fluid/memory/malloc.h" | ||
#include "paddle/fluid/operators/math/concat_and_split.h" | ||
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[TODO] 这里需要梳理下platform下还依赖了哪些组件,是需要我们提前迁移的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
#include <vector> | ||
#include "gflags/gflags.h" | ||
#include "paddle/fluid/framework/mixed_vector.h" | ||
#include "paddle/fluid/memory/malloc.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里依赖malloc有点严重,看下我们有替代写法吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后续优化
tmp_dev_ins_data = paddle::memory::Alloc(context, in_num * sizeof(T*)); | ||
auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph( | ||
inputs_data, in_num); | ||
paddle::memory::Copy(context.GetPlace(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
能否使用copy_kernel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copy_kernle 是对tensor的操作,这里使用的是指针地址
|
||
#include "paddle/pten/kernels/concat_kernel.h" | ||
|
||
#include "paddle/fluid/framework/lod_tensor.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么还需要lod_tensor,这里的头文件确认下必要性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里引用lod_tensor.h,是因为使用了AppendLoD 等辅助函数
@@ -0,0 +1,86 @@ | |||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2022
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
std::vector<pten::DenseTensor> pt_tensors; | ||
|
||
for(auto & t : tensors) { | ||
pt_tensors.push_back(*std::dynamic_pointer_cast<pten::DenseTensor>(t.impl())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里避免使用dynamic_cast,使用is_dense_tensor判断+static cast
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已沟通,先保持现状
auto in_lod = paddle::framework::ConvertToLengthBasedLoD(x[i].lod()); | ||
paddle::framework::AppendLoD(out_lod, in_lod); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个kernel文件里引入较多fluid下的代码,建议评估一下迁移难度,如果可以尽量将依赖函数迁移到pten下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
70c7d9b
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Describe
add concat pten kernel