[NewIR] Support Ir run program node #56791
@@ -98,6 +98,7 @@ inline void run_program_ad_func(
    std::vector<paddle::Tensor*>& dout,  // NOLINT
    const paddle::framework::AttributeMap& attrs) {
  // Prepare Autograd Meta
  VLOG(2) << "start run run_program ad function.";
  auto deref_out = details::DereferenceTensors(out);
  std::vector<egr::AutogradMeta*> p_autograd_x =
      egr::EagerUtils::nullable_autograd_meta(x);

@@ -174,3 +175,107 @@ inline void run_program_ad_func(
    egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node);
  }
}

inline void newir_run_program_ad_func(
    const std::vector<paddle::Tensor>& x,
    const std::vector<paddle::Tensor>& params,
    std::vector<paddle::Tensor*>& out,  // NOLINT
    std::vector<paddle::framework::Scope*>& step_scope,  // NOLINT
    std::vector<paddle::Tensor*>& dout,  // NOLINT
    const paddle::framework::AttributeMap& attrs) {
  // Prepare Autograd Meta
  VLOG(2) << "start run newir run_program ad function.";
  auto deref_out = details::DereferenceTensors(out);
  std::vector<egr::AutogradMeta*> p_autograd_x =
      egr::EagerUtils::nullable_autograd_meta(x);
  std::vector<egr::AutogradMeta*> p_autograd_params =
      egr::EagerUtils::nullable_autograd_meta(params);
  std::vector<egr::AutogradMeta*> p_autograd_outs =
      egr::EagerUtils::nullable_autograd_meta(deref_out);

  bool trace_backward = egr::Controller::Instance().HasGrad();
  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
      trace_backward, &p_autograd_x, &p_autograd_params);

  // Create Middle Output for GradNode.
  auto middle_size =
      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")).size();
Review comment: this should be decoupled from the run_program_op maker in a follow-up.
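As an editorial aside on that comment: the "fm" lookup above and the "fo" lookup below repeat the same attribute-size pattern, so one path toward decoupling could be a small helper. A minimal sketch; the helper name is hypothetical, only the PADDLE_GET_CONST access mirrors the diff:

// Hypothetical helper, not part of the PR: read a ::pir::Value list
// attribute from the map and return its length, as done for "fm"/"fo".
inline size_t GetValueListSize(const paddle::framework::AttributeMap& attrs,
                               const std::string& name) {
  return PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at(name)).size();
}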
  auto output_size =
      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")).size();
  auto middles = std::vector<paddle::Tensor*>();
  std::shared_ptr<NewIRGradNodeRunProgram> grad_node;
  VLOG(2) << "start run run_program with require_any_grad = "
          << require_any_grad;

  if (require_any_grad) {
    // Create GradOpNode (1 means [out_grad], 2 means [x_grad, param_grad])
    grad_node = std::make_shared<NewIRGradNodeRunProgram>(1, 2);
Review comment (suggested change): a namespace could perhaps be used for isolation here instead of the name prefix, e.g. pir::GradNodeRunProgram.
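A minimal sketch of what that suggestion might look like, assuming the grad node keeps its current base class and (bwd_in, bwd_out) slot constructor; the names and omitted members are illustrative, not the actual refactor:

// Hypothetical refactor sketch: move the node under a pir namespace.
namespace pir {
class GradNodeRunProgram : public egr::GradNodeBase {
 public:
  GradNodeRunProgram(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
      : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
  // operator(), name(), and the GetMiddle()/GetOutputs() accessors used
  // in this diff would move here unchanged.
};
}  // namespace pir

// The call site above would then read:
//   grad_node = std::make_shared<pir::GradNodeRunProgram>(1, 2);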
    grad_node->GetMiddle().resize(middle_size);
    grad_node->GetOutputs().resize(output_size);
    for (size_t i = 0; i < middle_size; ++i) {
      grad_node->GetMiddle()[i] =
          paddle::Tensor(std::make_shared<phi::DenseTensor>());
      middles.push_back(&grad_node->GetMiddle()[i]);
    }
    for (size_t i = 0; i < output_size; ++i) {
      grad_node->GetOutputs()[i] = *out[i];
    }
  }

  // Call forward function.
  // If require_any_grad is false, don't save any middle vars.
  NewIRRunProgramAPI(
      x, params, out, middles, step_scope, dout, require_any_grad, attrs);
  if (require_any_grad) {
    // auto x_names =
    //     PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));

    egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);

    // Set Attributes
    grad_node->SetAttrMap(attrs);

    // auto* forward_global_block = PADDLE_GET_CONST(
    //     paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
    // auto* backward_global_block = PADDLE_GET_CONST(
    //     paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
    // Clear unused x vars
    // auto filter_x = filter_unused_input_var_in_backward(
    //     x, x_names, backward_global_block);
    // Set TensorWrappers
    grad_node->SetFwdX(x);
    // Clear unused out vars
    // clear_unused_out_var_in_backward(out, backward_global_block,
    //                                  step_scope[0]);

    grad_node->SetFwdParams(params);
    grad_node->SetStepScope(step_scope);  // just to keep the scopes usable.

    // Set the grad-out rank to match the forward inputs and set stop
    // gradient for backward.
    // NOTE(@xiongkun): Not every tensor in x (a list of tensors) requires a
    // gradient; for example, if x[1] is not used for any output, x[1] is
    // ignored.

    // TODO(@xiongkun): rewrite with the new IR representation.
    std::vector<const paddle::Tensor*> x_require_grad;
    for (size_t i = 0; i < x.size(); ++i) {
      x_require_grad.push_back(&x[i]);
    }

    grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0);
    grad_node->SetGradOutMeta(params, /*slot id*/ 1);

    // VLOG(2) << "clear_no_grad_edges.";
    // clear_no_grad_edges_with_partial_block(params,
    //                                        forward_global_block,
    //                                        backward_global_block,
    //                                        grad_node.get(),
    //                                        /*slot id*/ 1);

    grad_node->SetGradInMeta(deref_out, 0);

    egr::EagerUtils::SetOutRankWithSlot(&p_autograd_outs, 0);

    // Set history for the outputs: set the current grad node for them.
    egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node);
  }
}
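For context on the helper both functions call first: details::DereferenceTensors is assumed here to turn the vector of output Tensor pointers into a vector of Tensor values so that nullable_autograd_meta can consume it. A minimal sketch of that assumed behavior, not the library's actual definition:

// Sketch of the assumed behavior of details::DereferenceTensors:
// copy each pointed-to tensor into a value vector.
inline std::vector<paddle::Tensor> DereferenceTensors(
    const std::vector<paddle::Tensor*>& tensor_ptrs) {
  std::vector<paddle::Tensor> res;
  res.reserve(tensor_ptrs.size());
  for (auto* t : tensor_ptrs) {
    res.emplace_back(*t);
  }
  return res;
}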