Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] test_ifelse ut add pir test branch #58163

Merged
merged 20 commits into from
Oct 19, 2023

Conversation

zhangbo9674
Copy link
Contributor

@zhangbo9674 zhangbo9674 commented Oct 17, 2023

PR types

Others

PR changes

Others

Description

随着 PIR 完成对 IF 前向执行的支持,将 test_ifelse 单测添加 PIR 的测试分支
TODO:本 PR 尚有5个测试 case 待开启 PIR 的测试分支,具体包括:

  • TestDygraphNestedIfElse、TestDygraphIfElseNet: if 嵌套 if 的场景下,子 block 找不到 父 block 的变量,分析中
  • TestDygraphIfTensor、TestDy2StIfElseRetInt3:cond_block 算子后会插入一些 fill_constant,导致翻译存在问题,分析中
  • TestDy2StIfElseBackward:if_grad 的执行待支持

Pcard-67164

@paddle-bot
Copy link

paddle-bot bot commented Oct 17, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -259,6 +259,9 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx,
continue;
}
VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]);
IR_ENFORCE(var != nullptr,
"Can't find var recursivelly from current block.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Can't find var recursivelly from current block.");
"Can't find var recursively from current block.");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return self._run_dygraph(to_static=True)

def _run_dygraph(self, to_static=False):
with base.dygraph.guard(place):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里新增的代码可以按照Paddle最新的API来写。比如:

  1. base.dygraph.guard → paddle.set_device()
  2. to_variable → to_tensor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

追加commit或者新PR都可以

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def _run_dygraph(self, to_static=False):
with base.dygraph.guard(place):
x_v = base.dygraph.to_variable(self.x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@zhangbo9674 zhangbo9674 merged commit 5e3974e into PaddlePaddle:develop Oct 19, 2023
28 checks passed
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Oct 24, 2023
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* add

* fix

* fix

* fix

* fix

* refine code
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 26, 2023
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* add

* fix

* fix

* fix

* fix

* refine code
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* add

* fix

* fix

* fix

* fix

* refine code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants