[PIR][AutoParallel] Support 1F1B/FThenB with PIR #58459
Your PR was submitted successfully. Thank you for contributing to this open-source project!
❌ The PR is not created using PR's template. You can refer to this Demo.
feed_name,
numel_size,
micro_batch_num));
int64_t split_size = (numel_size + micro_batch_num - 1) / micro_batch_num;
When numel_size % micro_batch_num == 0, (numel_size + micro_batch_num - 1) / micro_batch_num is equivalent to numel_size / micro_batch_num.
Fix it in the next PR
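The reviewer's point about split_size can be checked with a minimal Python sketch. The helper names ceil_div and floor_div are hypothetical, and Python's // stands in for C++ integer division on non-negative values. The sketch assumes, as the comment implies, that divisibility has already been enforced before split_size is computed.

```python
def ceil_div(n: int, k: int) -> int:
    """Rounded-up quotient, written the way split_size is computed in the diff."""
    return (n + k - 1) // k


def floor_div(n: int, k: int) -> int:
    """Plain integer quotient (the simplification the reviewer suggests)."""
    return n // k


# When n is an exact multiple of k, both forms give the same result.
for k in range(1, 9):
    for multiple in range(0, 20):
        n = multiple * k
        assert ceil_div(n, k) == floor_div(n, k)

# They only differ when there is a remainder, a case a divisibility
# check placed before the computation would already have rejected.
assert ceil_div(10, 4) == 3
assert floor_div(10, 4) == 2
```

If the divisibility check is ever relaxed, the two forms stop being interchangeable, so the simplification is only safe while that check stays in place.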
VLOG(4) << "Split feed data:" << feed_name << ", dims:("
<< feed_tensor.dims() << "), micro_batch_num:" << micro_batch_num;
for (int64_t j = 0; j < micro_batch_num; ++j) {
(*out)[j].resize(i + 1);
Growing the vector with resize on every iteration of the inner loop is inefficient. You can allocate (*out)[j] with a capacity of feed_names.size() up front instead.
Fix it in the next PR
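The restructured loop the reviewer describes might look like the following Python sketch, which is an analogy rather than the actual C++ fix. The function split_feeds and its dict-of-lists input are hypothetical stand-ins for the feed tensors. Each per-micro-batch row is sized once to len(feed_names) before the loops, so the inner loop only assigns into existing slots.

```python
def split_feeds(feeds: dict, micro_batch_num: int) -> list:
    """Split every feed into micro_batch_num equal chunks.

    Returns out, where out[j][i] is the chunk of the i-th feed that
    belongs to micro-batch j. `feeds` maps a feed name to a flat list,
    a stand-in for a tensor (assumption for illustration only).
    """
    feed_names = list(feeds)
    # Size every row once, up front, instead of growing it inside the loop.
    out = [[None] * len(feed_names) for _ in range(micro_batch_num)]

    for i, name in enumerate(feed_names):
        data = feeds[name]
        if len(data) % micro_batch_num != 0:
            raise ValueError(
                f"feed {name!r} of size {len(data)} is not divisible "
                f"by micro_batch_num={micro_batch_num}"
            )
        split_size = len(data) // micro_batch_num
        for j in range(micro_batch_num):
            # Slot i already exists, so this is a plain assignment.
            out[j][i] = data[j * split_size:(j + 1) * split_size]
    return out


chunks = split_feeds({"x": [0, 1, 2, 3], "y": [10, 20]}, micro_batch_num=2)
```

In C++ the equivalent would presumably be a single resize to feed_names.size() (or constructing each inner vector with that size) before the loop over feeds, since reserve alone does not make the elements indexable.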
@@ -278,6 +316,19 @@ def set_skip_gc_vars(num_micro_batches, type_to_program, jobs):
job.set_skip_gc_vars(skip_gc_vars)
suffixed_required_vars[micro_batch_id] |= required_vars

if get_flags("FLAGS_enable_new_ir_in_executor")[
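This logic appears unrelated to skip_gc_vars, so why is it written inside the set_skip_gc_vars function? Mixing the code that packs job_types and sub_programs into a dict into set_skip_gc_vars is also not recommended.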
Fix it in the next PR
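One way to act on this comment is to move the packing step into its own small helper so that set_skip_gc_vars only deals with garbage-collection variables. The sketch below is an assumption about how that separation could look: pack_type_to_program is a hypothetical name, not a Paddle API, and the strings stand in for real job types and program objects.

```python
def pack_type_to_program(job_types, sub_programs):
    """Pair each job type with its sub-program (hypothetical helper).

    Keeping this outside set_skip_gc_vars means that function can take
    the resulting dict as an input, matching its existing
    (num_micro_batches, type_to_program, jobs) signature.
    """
    if len(job_types) != len(sub_programs):
        raise ValueError(
            f"got {len(job_types)} job types but {len(sub_programs)} sub-programs"
        )
    return dict(zip(job_types, sub_programs))


# Placeholder values; real callers would pass program objects.
type_to_program = pack_type_to_program(
    ["forward", "backward"], ["fwd_prog", "bwd_prog"]
)
```

The caller would then build type_to_program once and pass it to set_skip_gc_vars, rather than having set_skip_gc_vars construct it as a side task.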
@@ -225,7 +225,44 @@ def var_can_be_deleted(var_name, block):
return var is not None and not var.persistable


def set_skip_gc_vars(num_micro_batches, type_to_program, jobs):
def prepare_ir_program(cur_prog, next_prog):
set_output_names = set()
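The "set" in set_output_names is ambiguous to read: it can mean either a set (the collection) or "to set" (the verb). Unless it is necessary, avoid putting basic type names such as list, array, or set into variable names.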
Fix it in the next PR
LGTM overall. Some coding problems should be improved in the future.
* [PIR][AutoParallel] Support 1F1B/FThenB with PIR * fix splitfeed * fix include * rm fetch * fix unittest * fix scope error * program interpreter use local scope to avoid var conflict * fix ut --------- Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
PR types
New features
PR changes
Others
Description
PCard-71568
From #58405
SplitFeedTensor method for pir
shadow_output op for pir to hold var
data op to construct input for pir