-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.52】 为 Paddle 新增 squeeze 和 unsqueeze 的 spmd 切分推导规则 #57877
【Hackathon 5th No.52】 为 Paddle 新增 squeeze 和 unsqueeze 的 spmd 切分推导规则 #57877
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@Ligoml 梦老师您好,能不能先让审核老师看一下这个PR的现存问题呀?我遇到瓶颈,进行不下去了 /大哭 |
@WintersMontagne10335 回复你的两个问题:
|
Sorry to inform you that 6c2f23f's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
@pkuzyc
|
@pkuzyc |
@WintersMontagne10335 单测跑不了是因为上面的报错吗?这个报错看上去是规则没有注册,是不是没有用最新的 develop 编译。reshape 的规则是已经合入了的,重新源码编译、安装 python 包之后应该不会报这个错。 |
找到问题,已经解决了。需要在cmake的时候开启 |
@WintersMontagne10335 看看能不能把 squeeze 和 unsqueeze 分开提?一个 pr 内容有点多 |
@pkuzyc 收到 |
@pkuzyc 老师您好,这个PR也可以review一下。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
输入axis为空的情况我们看一下
void MakeSqueezeDimTransWithAxis(const std::vector<int64_t>& x_shape, | ||
std::vector<int64_t>* out_shape, | ||
const std::vector<int64_t>& axis, | ||
std::vector<DimTrans*>* trans) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
应该可以写得更简单些,先算出输出的size,然后遍历输出的维度,如果某一维在 axis 里且 shape 是 1就用Singleton,否则用 InputDim。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
bool SqueezeCompare(const int64_t& a, const int64_t& b) { return a > b; } | ||
|
||
bool SqueezeReverseCompare(const int64_t& a, const int64_t& b) { return a < b; } | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不需要自己定义比较函数,sort对vector排序的时候默认从小到大排,标准库有自带的greater函数从大到小排
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
根据老师的 review
,简化了 MakeSqueezeDimTransWithAxis
的实现逻辑,已不需要排序
axis_copy[i] += x_ndim; | ||
} | ||
} | ||
std::sort(axis_copy.begin(), axis_copy.end(), SqueezeCompare); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉自己定义的比较函数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
axis_copy[i] += x_ndim; | ||
} | ||
} | ||
std::sort(axis_copy.begin(), axis_copy.end(), SqueezeReverseCompare); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用标准库自带的比较函数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) | ||
self.attrs = OrderedDict() | ||
|
||
def test_squeeze_infer_forward(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
补一个 '1' 的维度被切的单测,看看推导完有没有变成-1,例如:
[1, 8, 1, 16] --> [8, 1, 16]
[-1, 0, 1, -1] --> [-1, 0, -1, -1], [0, -1, -1]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done。本地看了,原先的实现是有问题的,已修改,目前推导完变成-1了。
@pkuzyc 老师您好,unsqueeze CI已过,可以review了。 |
Sorry to inform you that 9f95328's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
@WintersMontagne10335 unsqueeze合入了,可以更新squeeze 了 |
@WintersMontagne10335 最新一个pr把推导规则里用指针的地方改成智能指针了,需要按最新的改下:#59101 输入空list的情况我这两天再看看,要是支持起来比较复杂这个pr可以先合 |
收到!我今天改一下哈。 |
@pkuzyc Done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一些小地方再修改一下,之前 slice (#57866) 里的一些没有修改的看看也一起改了吧
if (x_shape[i] == 1) { | ||
auto it = find(axis.begin(), axis.end(), i); | ||
if (it == axis.end()) { | ||
trans->emplace_back(new Singleton()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里改成 make_shared
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if (x_shape[i] != 1) { | ||
trans->emplace_back(std::make_shared<InputDim>(j++)); | ||
} else { | ||
trans->emplace_back(new Singleton()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成 make_shared
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n; | ||
i++) { | ||
if (x_shape[i] == 1) { | ||
trans->emplace_back(new Singleton()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成 make_shared
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
VLOG(4) << "Transformation from input to output:"; | ||
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) { | ||
std::shared_ptr<DimTrans> t = trans[i]; | ||
VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接输出 trans[i]->to_string() 就行,不用多一个赋值
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
VLOG(4) << "Transformation from output to input:"; | ||
for (int64_t i = 0, n = trans.size(); i < n; i++) { | ||
std::shared_ptr<DimTrans> t = trans[i]; | ||
VLOG(4) << "\tX axis[" << i << "]: " << t->to_string(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,直接用 trans[i]->to_string() 就行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
收到! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…addlePaddle#57877) * Add spmd segmentation and derivation rules for squeeze to Paddle * Add spmd segmentation derivation rule for unsqueeze to Paddle * fix bugs * fix bugs * fix bugs * fix bugs * Add unit test code * modify squeeze.cc and CMakeLists.txt * write separate rules * fix bugs * fix bugs * fix bugs * remove unsqueeze spmd rule * modified: test/auto_parallel/spmd_rules/test_squeeze_rule.py * re-run CI * fix bugs * modify pointer to smart pointer * fix bugs * fix bugs
PR types
Others
PR changes
Others
Description
为 Paddle 新增 squeeze 和 unsqueeze 的 spmd 切分推导规则
#57262
https://github.com/PaddlePaddle/community/blob/master/pfcc/paddle-code-reading/auto_parallel/spmd_rules.md
切分推导规则的输入参数,不支持空list
【Hackathon 5th No.52】 为 Paddle 新增 unsqueeze 的 spmd 切分推导规则 -part #58296