run program contiguous tensor first #57431

Merged
31 changes: 27 additions & 4 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -90,6 +90,26 @@ static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
return filter_x;
}

// Returns a copy of `tensors` in which every initialized, non-contiguous
// dense tensor has been converted to a contiguous one; all other tensors
// are passed through unchanged, and autograd meta is kept on the copies.
static std::vector<paddle::Tensor> Trans2ContiguousTensors(
    const std::vector<paddle::Tensor>& tensors) {
  std::vector<paddle::Tensor> res;
  for (const auto& t : tensors) {
    if (t.is_initialized() && t.is_dense_tensor() &&
        !std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())
             ->meta()
             .is_contiguous()) {
      res.emplace_back(
          std::make_shared<phi::DenseTensor>(
              std::move(paddle::experimental::Trans2Contiguous(
                  *(std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()))))),
          t.mutable_autograd_meta());
    } else {
      res.emplace_back(t);
    }
  }
  return res;
}

inline void run_program_ad_func(
const std::vector<paddle::Tensor>& x,
const std::vector<paddle::Tensor>& params,
@@ -112,9 +132,12 @@ inline void run_program_ad_func(

VLOG(2) << "start run run_program with require_any_grad = "
<< require_any_grad;
// Convert non-contiguous inputs and parameters to contiguous tensors before
// running the program.
auto x_tmp = Trans2ContiguousTensors(x);
auto params_tmp = Trans2ContiguousTensors(params);
// Call forward function
// if require_any_grad is False, don't save any middle vars.
-RunProgramAPI(x, params, out, step_scope, dout, require_any_grad, attrs);
+RunProgramAPI(
+    x_tmp, params_tmp, out, step_scope, dout, require_any_grad, attrs);
VLOG(2) << "start run run_program grad";

if (require_any_grad) {
@@ -133,14 +156,14 @@ inline void run_program_ad_func(
auto* backward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
// Clear unused x vars
-auto filter_x =
-    filter_unused_input_var_in_backward(x, x_names, backward_global_block);
+auto filter_x = filter_unused_input_var_in_backward(
+    x_tmp, x_names, backward_global_block);
// Set TensorWrappers
grad_node->SetFwdX(filter_x);
// Clear unused out vars
clear_unused_out_var_in_backward(out, backward_global_block, step_scope[0]);

-grad_node->SetFwdParams(params);
+grad_node->SetFwdParams(params_tmp);
grad_node->SetStepScope(step_scope);

// Set Grad out rank as same as fwd input and set stop gradient to bwd
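Background for the change above: a tensor can be a strided view of its storage (for example, the result of a transpose or a slice), in which case its elements are not laid out densely in memory. The forward path now converts such inputs and parameters to contiguous copies (via paddle::experimental::Trans2Contiguous, as shown in the diff) before handing them to RunProgramAPI. Below is a minimal, framework-free sketch of what that repacking means; View2D and ToContiguous are illustrative names, not Paddle APIs.

#include <cstdio>
#include <vector>

// A toy 2-D "tensor view": raw data plus shape and strides (in elements).
struct View2D {
  const float* data;
  int rows, cols;
  int row_stride, col_stride;
  // Row-major contiguous means the columns are adjacent and each row
  // starts right after the previous one.
  bool is_contiguous() const { return col_stride == 1 && row_stride == cols; }
};

// Repack a possibly strided view into a freshly allocated contiguous buffer,
// which is conceptually what a Trans2Contiguous-style conversion does.
std::vector<float> ToContiguous(const View2D& v) {
  std::vector<float> out;
  out.reserve(static_cast<size_t>(v.rows) * v.cols);
  for (int r = 0; r < v.rows; ++r) {
    for (int c = 0; c < v.cols; ++c) {
      out.push_back(v.data[r * v.row_stride + c * v.col_stride]);
    }
  }
  return out;
}

int main() {
  // A 2x3 matrix stored row-major...
  float storage[6] = {0, 1, 2, 3, 4, 5};
  // ...viewed as its 3x2 transpose by swapping strides (no data copied).
  View2D transposed{storage, 3, 2, /*row_stride=*/1, /*col_stride=*/3};
  std::printf("contiguous? %d\n", transposed.is_contiguous());  // prints 0
  for (float x : ToContiguous(transposed)) std::printf("%.0f ", x);  // 0 3 1 4 2 5
  std::printf("\n");
  return 0;
}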
22 changes: 22 additions & 0 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -24,6 +24,7 @@
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/core/value.h"

@@ -32,6 +33,25 @@ PHI_DECLARE_bool(enable_new_ir_in_executor);
namespace details {
using Tensor = paddle::Tensor;

// Converts every initialized, non-contiguous dense tensor in `tensors` to a
// contiguous layout in place: the contiguous data produced by
// Trans2Contiguous is swapped into the existing tensor's holder and meta, so
// callers keep working with the same Tensor objects.
static void Trans2ContiguousTensorsInplace(
    const std::vector<paddle::Tensor> &tensors) {
  for (const auto &t : tensors) {
    if (t.is_initialized() && t.is_dense_tensor() &&
        !std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())
             ->meta()
             .is_contiguous()) {
      auto tmp = paddle::experimental::Trans2Contiguous(
          *(std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())));
      auto holder = tmp.MoveMemoryHolder();
      std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())->ResetHolder(
          holder);
      std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())->set_meta(
          tmp.meta());
    }
  }
}

static std::vector<Tensor> DereferenceTensors(
const std::vector<Tensor *> &tensor_ptr) {
std::vector<Tensor> res;
@@ -544,6 +564,8 @@ inline void RunProgramGradAPI(
paddle::framework::BlockDesc *, attrs.at("backward_global_block"));
auto *backward_program = backward_global_block->Program();

// Gradients flowing into the backward program may be non-contiguous (e.g.
// produced by view ops); convert them in place before execution.
details::Trans2ContiguousTensorsInplace(out_grad);

auto out_grad_names = details::GetTensorsName(out_grad);
auto &interpretercore_info_cache =
paddle::framework::InterpreterCoreInfoCache::Instance();
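The grad-side helper above does the same repacking but in place: instead of constructing new Tensor objects, it moves the contiguous allocation into the existing tensor's holder and updates its meta (MoveMemoryHolder / ResetHolder / set_meta in the diff), so code that already holds those out_grad tensors keeps seeing the same objects. A rough sketch of that pattern under the same toy model as before; ToyTensor and MakeContiguousInplace are illustrative names, not Paddle APIs.

#include <cstdio>
#include <memory>
#include <vector>

// Toy stand-in for a tensor whose storage sits behind a shared handle,
// loosely mirroring how a phi::DenseTensor is reached through t.impl().
struct ToyTensor {
  std::shared_ptr<std::vector<float>> holder;  // shared storage
  int rows, cols;
  int row_stride, col_stride;

  bool is_contiguous() const { return col_stride == 1 && row_stride == cols; }

  // In-place analogue of ResetHolder + set_meta: build a contiguous buffer
  // and swap it into this object, so every reference to the tensor now sees
  // contiguous data without being handed a new tensor.
  void MakeContiguousInplace() {
    if (is_contiguous()) return;
    auto packed = std::make_shared<std::vector<float>>();
    packed->reserve(static_cast<size_t>(rows) * cols);
    for (int r = 0; r < rows; ++r)
      for (int c = 0; c < cols; ++c)
        packed->push_back((*holder)[r * row_stride + c * col_stride]);
    holder = std::move(packed);  // "ResetHolder"
    row_stride = cols;           // "set_meta"
    col_stride = 1;
  }
};

int main() {
  auto storage = std::make_shared<std::vector<float>>(
      std::vector<float>{0, 1, 2, 3, 4, 5});
  // A transposed (non-contiguous) view over the shared storage.
  ToyTensor t{storage, 3, 2, /*row_stride=*/1, /*col_stride=*/3};
  t.MakeContiguousInplace();
  std::printf("contiguous? %d, data starts: %.0f %.0f\n",
              t.is_contiguous(), (*t.holder)[0], (*t.holder)[1]);  // 1, 0 3
  return 0;
}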