-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support stride #55156
Support stride #55156
Conversation
… support_stride
… support_stride
… support_stride
… support_stride2
… support_stride2
… support_stride2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
backup.emplace_back(nullptr); | ||
} else { | ||
backup.emplace_back(t); | ||
t = new phi::DenseTensor(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里new出来的对象,在哪里做的delete?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个计划在TransStride里做的,忘记了。我给加上了感谢!
"squeeze", | ||
"unsqueeze", | ||
"reshape", | ||
"flatten", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后续这个名单是否会考虑类似文件前面的PREFIX_TENSOR_NAME
一样的全局变量标记一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个名单的意思是,这4个C++ API,本身已经支持stride。不需要再通过ProcessStrideBackup和TransStride来特殊处理inplace了。暂时没有再修改的计划。
@@ -210,6 +217,9 @@ DDim make_ddim(std::initializer_list<int64_t> dims); | |||
|
|||
template <typename T = int64_t> | |||
std::vector<T> vectorize(const DDim& ddim) { | |||
if (ddim.size() == -1) { | |||
return std::vector<T>({0}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
若 shape = [0]
是否意味着其 ddim.size() == -1 ?这里是特殊标记了一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-1是我新引入的,表示ddim还没有初始化。
shape=[0]的意思是已经初始化过了size()==1~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
单测超时时间设置
if (!strided_kernel_used_) { | ||
meta->strides = meta->calc_strides(dims); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是非strided_kernel时才需要计算stride信息吗?感觉有些和直觉相反
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,stride kernel会在真正kernel里计算
if (!strided_kernel_used_) { | ||
meta->strides = meta->calc_strides(meta->dims); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set_layout
和set_dims
各处理一次stride是重复计算吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
layout和dims发生改变时,需要重新计算stride。所以两个都加了
paddle/phi/infermeta/unary.cc
Outdated
@@ -18,6 +18,7 @@ limitations under the License. */ | |||
#include <set> | |||
|
|||
#include "gflags/gflags.h" | |||
#include "paddle/fluid/platform/flags.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
phi独立编译后不能再使用fluid的头文件了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,感谢。改成使用#include "paddle/phi/core/flags.h"了
6198828
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
有以下几点文档的问题,之后提pr优化叭~
contiguous, is_contiguous, detach_, Tensor.data, Tensor.grad(赋值操作)
这几个api文档没找着,是没暴露嘛?- 请补充一下API相应的中文文档
PR types
New features
PR changes
Others
Description
Pcard-71699
添加Stride机制,主要工作包括:
目前stride策略默认关闭,待测试稳定后设为默认打开。开关为:FLAGS_use_stride_kernel
FLAGS_use_stride_kernel为1时如下API将支持stride:slice, strided_slice, index_select, split, chunk, unbind, as_real, detach, transpose, t, moveaxis, moveaxis, moveaxis, T, reshape, squeeze, unsqueeze, squeeze_, unsqueeze_, flatten, diagonal, imag, real
此外动态图下新增了如下API:view, view_as, as_strided, unfold, contiguous, is_contiguous, detach_, Tensor.data, Tensor.grad(赋值操作)
目前只有动态图支持stride,静态图尚不支持。这可能会发生动静不一致的情况。如果动转静时发生动静不一致,框架将报错处理。