diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h
index 8f6e6f4028c1d..289073095cc4f 100644
--- a/paddle/fluid/eager/to_static/run_program_op_func.h
+++ b/paddle/fluid/eager/to_static/run_program_op_func.h
@@ -90,6 +90,26 @@ static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
   return filter_x;
 }
 
+static std::vector<paddle::Tensor> Trans2ContiguousTensors(
+    const std::vector<paddle::Tensor>& tensors) {
+  std::vector<paddle::Tensor> res;
+  for (const auto& t : tensors) {
+    if (t.is_initialized() && t.is_dense_tensor() &&
+        !std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())
+             ->meta()
+             .is_contiguous()) {
+      res.emplace_back(
+          std::make_shared<phi::DenseTensor>(
+              std::move(paddle::experimental::Trans2Contiguous(
+                  *(std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()))))),
+          t.mutable_autograd_meta());
+    } else {
+      res.emplace_back(t);
+    }
+  }
+  return res;
+}
+
 inline void run_program_ad_func(
     const std::vector<paddle::Tensor>& x,
     const std::vector<paddle::Tensor>& params,
@@ -112,9 +132,12 @@ inline void run_program_ad_func(
   VLOG(2) << "start run run_program with require_any_grad = "
           << require_any_grad;
+  auto x_tmp = Trans2ContiguousTensors(x);
+  auto params_tmp = Trans2ContiguousTensors(params);
   // Call forward function
   // if require_any_grad is False, don't save any middle vars.
-  RunProgramAPI(x, params, out, step_scope, dout, require_any_grad, attrs);
+  RunProgramAPI(
+      x_tmp, params_tmp, out, step_scope, dout, require_any_grad, attrs);
 
   VLOG(2) << "start run run_program grad";
 
   if (require_any_grad) {
@@ -133,14 +156,14 @@ inline void run_program_ad_func(
     auto* backward_global_block = PADDLE_GET_CONST(
         paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
     // Clear unused x vars
-    auto filter_x =
-        filter_unused_input_var_in_backward(x, x_names, backward_global_block);
+    auto filter_x = filter_unused_input_var_in_backward(
+        x_tmp, x_names, backward_global_block);
     // Set TensorWrappers
     grad_node->SetFwdX(filter_x);
     // Clear unused out vars
     clear_unused_out_var_in_backward(out, backward_global_block, step_scope[0]);
 
-    grad_node->SetFwdParams(params);
+    grad_node->SetFwdParams(params_tmp);
     grad_node->SetStepScope(step_scope);
 
     // Set Grad out rank as same as fwd input and set stop gradient to bwd
diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h
index c243f2c7b04d8..3c53edbba8a9b 100644
--- a/paddle/fluid/eager/to_static/run_program_op_node.h
+++ b/paddle/fluid/eager/to_static/run_program_op_node.h
@@ -24,6 +24,7 @@
 #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
+#include "paddle/phi/api/lib/data_transform.h"
 #include "paddle/pir/core/program.h"
 #include "paddle/pir/core/value.h"
 
@@ -32,6 +33,25 @@ PHI_DECLARE_bool(enable_new_ir_in_executor);
 namespace details {
 using Tensor = paddle::Tensor;
 
+static void Trans2ContiguousTensorsInplace(
+    const std::vector<paddle::Tensor> &tensors) {
+  std::vector<paddle::Tensor> res;
+  for (auto &t : tensors) {
+    if (t.is_initialized() && t.is_dense_tensor() &&
+        !std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())
+             ->meta()
+             .is_contiguous()) {
+      auto tmp = paddle::experimental::Trans2Contiguous(
+          *(std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())));
+      auto holder = tmp.MoveMemoryHolder();
+      std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())->ResetHolder(
+          holder);
+      std::dynamic_pointer_cast<phi::DenseTensor>(t.impl())->set_meta(
+          tmp.meta());
+    }
+  }
+}
+
 static std::vector<Tensor> DereferenceTensors(
     const std::vector<Tensor *> &tensor_ptr) {
   std::vector<Tensor> res;
@@ -544,6 +564,8 @@ inline void RunProgramGradAPI(
       paddle::framework::BlockDesc *, attrs.at("backward_global_block"));
   auto *backward_program = backward_global_block->Program();
 
+  details::Trans2ContiguousTensorsInplace(out_grad);
+
   auto out_grad_names = details::GetTensorsName(out_grad);
   auto &interpretercore_info_cache =
       paddle::framework::InterpreterCoreInfoCache::Instance();
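
Note (reviewer sketch, not part of the patch): both new helpers follow the same pattern -- check the DenseTensor's meta().is_contiguous() flag, and only when it is false materialize a packed copy via paddle::experimental::Trans2Contiguous, either wrapping it in a fresh Tensor (forward path) or swapping the holder and meta in place (backward path). The standalone C++ sketch below only illustrates that check-and-repack idea on a toy strided tensor; ToyTensor, IsContiguous, and MakeContiguousInplace are hypothetical names and are not Paddle APIs.

#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Toy stand-in for a strided dense tensor; the fields loosely model what a
// dense-tensor meta tracks (dims, strides) plus the underlying data holder.
struct ToyTensor {
  std::vector<int64_t> dims;     // logical shape
  std::vector<int64_t> strides;  // element stride per dimension
  std::vector<float> data;       // underlying buffer (may be strided)
};

// Row-major contiguity: strides[i] == dims[i+1] * ... * dims[n-1].
static bool IsContiguous(const ToyTensor& t) {
  int64_t expected = 1;
  for (int i = static_cast<int>(t.dims.size()) - 1; i >= 0; --i) {
    if (t.strides[i] != expected) return false;
    expected *= t.dims[i];
  }
  return true;
}

// Mirrors the diff's in-place variant: skip tensors that are already
// contiguous, otherwise gather the elements into a packed buffer and reset
// the buffer and strides (analogous to ResetHolder + set_meta above).
static void MakeContiguousInplace(ToyTensor* t) {
  if (IsContiguous(*t)) return;  // fast path: nothing to do
  const int64_t numel = std::accumulate(
      t->dims.begin(), t->dims.end(), int64_t{1}, std::multiplies<int64_t>());
  std::vector<float> packed(static_cast<size_t>(numel));
  std::vector<int64_t> idx(t->dims.size(), 0);  // odometer over logical index
  for (int64_t linear = 0; linear < numel; ++linear) {
    int64_t offset = 0;
    for (size_t d = 0; d < idx.size(); ++d) offset += idx[d] * t->strides[d];
    packed[static_cast<size_t>(linear)] = t->data[static_cast<size_t>(offset)];
    for (int d = static_cast<int>(idx.size()) - 1; d >= 0; --d) {
      if (++idx[d] < t->dims[d]) break;
      idx[d] = 0;
    }
  }
  t->data = std::move(packed);  // swap in the packed holder
  int64_t stride = 1;           // recompute row-major strides
  for (int i = static_cast<int>(t->dims.size()) - 1; i >= 0; --i) {
    t->strides[i] = stride;
    stride *= t->dims[i];
  }
}

For instance, a 2x3 view with strides {1, 2} (a transposed layout) fails the contiguity check and is repacked into a fresh buffer with row-major strides {3, 1}; a tensor that already has strides {3, 1} is returned untouched, which is why the helpers above are cheap no-ops on the common contiguous case.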