diff --git a/paddle/fluid/eager/custom_operator/CMakeLists.txt b/paddle/fluid/eager/custom_operator/CMakeLists.txt index a2648d3e32556..a74ba2dc8c628 100644 --- a/paddle/fluid/eager/custom_operator/CMakeLists.txt +++ b/paddle/fluid/eager/custom_operator/CMakeLists.txt @@ -1,4 +1,9 @@ cc_library( custom_operator_node SRCS custom_operator_node.cc + DEPS phi grad_node_info custom_operator utils custom_operator_utils) + +cc_library( + custom_operator_utils + SRCS custom_operator_utils.cc DEPS phi grad_node_info custom_operator utils) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 5643c0e69391f..9b6318c7a43ed 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/eager/custom_operator/custom_operator_node.h" +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" @@ -172,8 +173,6 @@ RunCustomOpNode::operator()(paddle::small_vector, paddle::OpMetaInfoHelper::GetInputs(vec_map[1]); const auto& grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[1]); - const auto& grad_inplace_map = - paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[1]); const auto& map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); @@ -251,11 +250,12 @@ RunCustomOpNode::operator()(paddle::small_vector, } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad"; - // handle inplace map - ctx.UpdatePlainOutputs( - grad_inputs_name, grad_outputs_names, grad_inplace_map); - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[1]))(&ctx); - ctx.AssignInplaceOutputs(); + run_custom_op_impl(vec_map[1], false, false, ctx); + + for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { + auto output_pair = ctx.OutputRangeAt(i); + outs[i] = ctx.OutputsBetween(output_pair.first, output_pair.second); + } // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { @@ -264,7 +264,9 @@ RunCustomOpNode::operator()(paddle::small_vector, ctx.MutableOutputAt(ctx.OutputRangeAt(i).first); if (!out_tensor->initialized()) { PADDLE_ENFORCE( - paddle::framework::detail::IsOptionalVar(grad_outputs_names.at(i)), + paddle::framework::detail::IsOptionalVar( + grad_outputs_names.at(i)) || + out_tensor->is_dist_tensor(), phi::errors::InvalidArgument( "Custom grad operator's %d-th output is not initialized. " "Please check your implementation again. 
If you are " @@ -386,8 +388,6 @@ RunCustomOpDoubleGradNode::operator()( paddle::OpMetaInfoHelper::GetInputs(vec_map[2]); const auto& grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[2]); - const auto& grad_inplace_map = - paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[2]); const auto& map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); @@ -451,11 +451,12 @@ RunCustomOpDoubleGradNode::operator()( } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad_grad"; - // handle inplace map - ctx.UpdatePlainOutputs( - grad_inputs_name, grad_outputs_names, grad_inplace_map); - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[2]))(&ctx); - ctx.AssignInplaceOutputs(); + run_custom_op_impl(vec_map[2], false, true, ctx); + + for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { + auto output_pair = ctx.OutputRangeAt(i); + outs[i] = ctx.OutputsBetween(output_pair.first, output_pair.second); + } return outs; } diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc new file mode 100644 index 0000000000000..7985ef92285d0 --- /dev/null +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -0,0 +1,709 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" + +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/framework/custom_operator.h" +#include "paddle/fluid/framework/custom_operator_utils.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/api/lib/data_transform.h" +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/flags.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/api/lib/api_gen_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif + +namespace egr { + +using Tensor = paddle::Tensor; + +static std::vector> RunDefaultInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. 
" + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + + VLOG(3) << "Custom Operator: Default InferShape - share ddim."; + result.push_back({ctx.InputAt(0).dims()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. Please check `SetInplaceMap` again or set " + "the InferShapeFn of custom operator by " + "`.SetInferShapeFn(PD_INFER_SHAPE(...)`)", + outputs.size(), + inplace_map.size())); + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(inputs[i])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(i); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(i); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } + } + } + return result; +} + +static std::vector> RunDefaultGradInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& grad_op_inputs, + const std::vector& grad_op_outputs, + bool is_double_grad) { + std::vector> result; + // 1. if forward input exists, gradient's shape is same with forward + // input + // default + // [Suitable for most situations] + // 2. if forward input not exists, and only contains one grad input and + // output, + // use grad input shape as grad output shape + // [Suitable for the situation that forward input is not used as + // backward input] + for (auto& out_name : grad_op_outputs) { + auto fwd_name = paddle::framework::detail::NoGrad(out_name, is_double_grad); + if (paddle::framework::detail::IsDuplicableVar(fwd_name)) { + // Duplicable forward var must as backward input + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the forward " + "input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + std::vector tmp; + for (size_t i = pair.first; i < pair.second; ++i) { + tmp.emplace_back(ctx.InputAt(i).dims()); + } + result.emplace_back(std::move(tmp)); + } else { + if (grad_op_inputs.size() == grad_op_outputs.size()) { + result.push_back({ctx.InputAt(0).dims()}); + } else { + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the " + "forward input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + result.push_back({ctx.InputAt(pair.first).dims()}); + } + } + } + return result; +} + +static std::vector> RunInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferShapeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector> input_shapes; + std::vector>> vec_input_shapes; + + VLOG(3) 
<< "Custom Operator: InferShape - get input ddim."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second - 1) { + input_shapes.emplace_back( + std::move(ctx.InputAt(input_pair.first).shape())); + } else { + std::vector> shapes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + shapes.push_back(std::move(ctx.InputAt(j).shape())); + } + vec_input_shapes.emplace_back(std::move(shapes)); + } + } + + VLOG(3) << "Custom Operator: InferShape - calc output ddim."; + auto output_shapes = func(input_shapes, vec_input_shapes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_shapes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferShapeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferShapeFn = %d. Please " + "check InferShapeFn again.", + outputs.size(), + output_shapes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_shapes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferShapeFn output size` = %d. Please check " + "InplaceMap and InferShapeFn again", + outputs.size(), + output_shapes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferShape - set output ddim: inplace_map.size() = " + << inplace_map.size() + << ", output_shapes.size() = " << output_shapes.size(); + size_t output_shape_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } else { + result.push_back({phi::make_ddim(output_shapes[output_shape_idx++])}); + } + } + } + return result; +} + +static std::vector> RunDefaultInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. 
" + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + + VLOG(3) << "Custom Operator: InferDtype - share dtype."; + result.push_back({ctx.InputAt(0).dtype()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. Please check `SetInplaceMap` again or set " + "the InferDtypeFn of custom operator by " + "`.SetInferDtypeFn(PD_INFER_DTYPE(...))`", + outputs.size(), + inplace_map.size())); + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(inputs[i])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(i); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(i); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } + } + } + return result; +} + +static std::vector> RunDefaultGradInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& grad_op_inputs, + const std::vector& grad_op_outputs, + bool is_double_grad) { + std::vector> result; + for (auto& out_name : grad_op_outputs) { + auto fwd_name = paddle::framework::detail::NoGrad(out_name, is_double_grad); + if (paddle::framework::detail::IsDuplicableVar(fwd_name)) { + // Duplicable forward var must as backward input + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the forward " + "input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + std::vector tmp; + for (size_t i = pair.first; i < pair.second; ++i) { + tmp.emplace_back(ctx.InputAt(i).dtype()); + } + result.emplace_back(std::move(tmp)); + } else { + if (grad_op_inputs.size() == grad_op_outputs.size()) { + result.push_back({ctx.InputAt(0).dtype()}); + } else { + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the " + "forward input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + result.push_back({ctx.InputAt(pair.first).dtype()}); + } + } + } + return result; +} + +static std::vector> RunInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferDtypeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector input_dtypes; + std::vector> vec_input_dtypes; + + VLOG(3) << "Custom Operator: InferDtype - get input 
dtype."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second - 1) { + input_dtypes.emplace_back( + std::move(ctx.InputAt(input_pair.first).dtype())); + } else { + std::vector dtypes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + dtypes.emplace_back(ctx.InputAt(j).dtype()); + } + vec_input_dtypes.emplace_back(std::move(dtypes)); + } + } + + VLOG(3) << "Custom Operator: InferDtype - infer output dtype."; + auto output_dtypes = func(input_dtypes, vec_input_dtypes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_dtypes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferDtypeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferDtypeFn = %d. Please " + "check InferDtypeFn again.", + outputs.size(), + output_dtypes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_dtypes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferDtypeFn output size` = %d. Please check " + "InplaceMap and InferDtypeFn again", + outputs.size(), + output_dtypes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferDtype - set output dtype: inplace_map.size() = " + << inplace_map.size() + << ", output_dtypes.size() = " << output_dtypes.size(); + size_t output_dtype_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. 
If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector dtypes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + dtypes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(dtypes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } else { + result.push_back({output_dtypes[output_dtype_idx++]}); + } + } + } + return result; +} + +#ifdef PADDLE_WITH_DISTRIBUTE +paddle::Tensor BuildEmptyDistPaddleTensor( + const phi::distributed::ProcessMesh& process_mesh, + const phi::DDim& dims, + phi::DataType dtype) { + paddle::Tensor empty_tensor; + phi::DenseTensorMeta meta; + meta.dims = dims; + meta.dtype = dtype; + + auto dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(dims)); + dist_attr.set_process_mesh(process_mesh); + + auto dist_t = std::make_shared( + std::make_shared( + std::make_shared( + nullptr, 0, phi::distributed::GetDefaultPlace()), + meta), + dist_attr); + empty_tensor.set_impl(dist_t); + empty_tensor.set_autograd_meta(std::make_shared()); + return empty_tensor; +} +#endif + +#ifdef PADDLE_WITH_DISTRIBUTE +std::tuple PrepareCtxForAutoParallel( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx) { // NOLINT + bool run_auto_parallel = false; + bool rank_is_in_current_mesh = true; + phi::distributed::ProcessMesh current_process_mesh; + + const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); + const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); + + std::vector* all_inputs = ctx.AllMutableInput(); + std::vector x = *all_inputs; + const phi::distributed::ProcessMesh* mesh = nullptr; + for (auto& input : x) { + if (input.is_dist_tensor()) { + mesh = &( + std::dynamic_pointer_cast(input.impl()) + ->dist_attr() + .process_mesh()); + break; + } + } + + if (mesh) { + for (auto& input : x) { + if (input.is_dist_tensor()) { + PADDLE_ENFORCE_EQ( + std::dynamic_pointer_cast( + input.impl()) + ->dist_attr() + .process_mesh(), + *mesh, + phi::errors::InvalidArgument( + "Input %s has different mesh. 
However all inputs should " + "have the same mesh.", + input.name())); + } else { + PADDLE_ENFORCE_EQ( + phi::DenseTensor::classof(input.impl().get()), + true, + phi::errors::InvalidArgument("Failed to convert input %s impl " + "to phi::distributed::DistTensor " + "as it's not phi::DenseTensor.", + input.name())); + phi::distributed::TensorDistAttr dist_attr( + phi::vectorize(input.impl()->dims())); + dist_attr.set_process_mesh(*mesh); + auto dense_t = std::static_pointer_cast(input.impl()); + input.set_impl( + std::make_shared(dense_t, dist_attr)); + } + } + } + + run_auto_parallel = paddle::experimental::AllInputsAreDistTensor(x); + rank_is_in_current_mesh = true; + if (run_auto_parallel) { + auto mesh = + std::static_pointer_cast(x.at(0).impl()) + ->dist_attr() + .process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + std::vector input_x(x.size()); + for (size_t i = 0; i < input_x.size(); ++i) { + input_x[i] = x.at(i).impl().get(); + } + + auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); + auto spmd_info = + phi::distributed::VariadicReplicatedInferSpmdDynamic(meta_dist_input_x); + current_process_mesh = + paddle::holds_alternative( + spmd_info.first[0]) + ? paddle::get<0>(spmd_info.first[0]).process_mesh() + : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh(); + + if (rank_is_in_current_mesh) { + auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); + auto dist_input_x = paddle::experimental::ReshardApiInputToKernelInput( + dev_ctx, x, spmd_info.first[0]); + for (size_t i = 0; i < x.size(); ++i) { + all_inputs->at(i).set_impl( + std::make_shared(dist_input_x[i]->value())); + } + } else { + auto& infer_shape_func = + paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); + auto& infer_dtype_func = + paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); + + std::vector> out_dims; + if (infer_shape_func) { + out_dims = RunInferShapeFunc( + ctx, infer_shape_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dims = + RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dims = RunDefaultGradInferShapeFunc( + ctx, inputs, outputs, is_double_grad); + } + } + + std::vector> out_dtypes; + if (infer_dtype_func) { + out_dtypes = RunInferDtypeFunc( + ctx, infer_dtype_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dtypes = + RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dtypes = RunDefaultGradInferDtypeFunc( + ctx, inputs, outputs, is_double_grad); + } + } + + PADDLE_ENFORCE_EQ( + out_dims.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_shape return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dims.size())); + + PADDLE_ENFORCE_EQ( + out_dtypes.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_dtype return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dtypes.size())); + + for (size_t i = 0; i < out_dims.size(); ++i) { + const auto& out_dim = out_dims.at(i); + const auto& out_dtype = out_dtypes.at(i); + const auto& pair = ctx.OutputRangeAt(i); + PADDLE_ENFORCE_EQ( + out_dim.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dim.size())); + PADDLE_ENFORCE_EQ( + out_dtype.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape 
result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dtype.size())); + + if (out_dim.size() == 1) { + *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[0], out_dtype[0]); + } else { + for (size_t j = pair.first; j < pair.second; j++) { + *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[j], out_dtype[j]); + } + } + } + return std::tuple( + run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); + } + } + return std::tuple( + run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); +} +#endif + +#ifdef PADDLE_WITH_DISTRIBUTE +void TransCtxTensorsToDistTensors( + paddle::CustomOpKernelContext& ctx, // NOLINT + bool run_auto_parallel, + const phi::distributed::ProcessMesh& current_process_mesh) { + if (run_auto_parallel) { + std::vector* output_all = ctx.AllMutableOutput(); + for (size_t i = 0; i < output_all->size(); ++i) { + auto& tensor = output_all->at(i); + phi::distributed::TensorDistAttr dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); + tensor.set_impl(dist_t); + } + std::vector* input_all = ctx.AllMutableInput(); + for (size_t i = 0; i < input_all->size(); ++i) { + auto& tensor = input_all->at(i); + phi::distributed::TensorDistAttr dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); + tensor.set_impl(dist_t); + } + } +} +#endif + +void run_custom_op_impl(const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx) { // NOLINT + const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); + const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); + ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); + +#ifdef PADDLE_WITH_DISTRIBUTE + auto result = + PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx); + bool run_auto_parallel = std::get<0>(result); + bool rank_is_in_current_mesh = std::get<1>(result); + phi::distributed::ProcessMesh current_process_mesh = std::get<2>(result); + if (!rank_is_in_current_mesh) { + return; + } +#endif + + std::vector* all_inputs = ctx.AllMutableInput(); + for (size_t i = 0; i < all_inputs->size(); ++i) { + auto& tensor = all_inputs->at(i); + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast(tensor.impl())))))); + } + } + + // handle inplace map + ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); + VLOG(7) << "Begin run Kernel of Custom Op"; + (*paddle::OpMetaInfoHelper::GetKernelFn(op_info))(&ctx); + ctx.AssignInplaceOutputs(); + +#ifdef PADDLE_WITH_DISTRIBUTE + TransCtxTensorsToDistTensors(ctx, run_auto_parallel, current_process_mesh); +#endif +} + +} // namespace egr diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.h b/paddle/fluid/eager/custom_operator/custom_operator_utils.h new file mode 100644 index 0000000000000..ac2dec37f3d34 --- /dev/null +++ 
b/paddle/fluid/eager/custom_operator/custom_operator_utils.h @@ -0,0 +1,24 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/api/ext/op_meta_info.h" + +namespace egr { +void run_custom_op_impl(const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx); // NOLINT +} // namespace egr diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index 5851e3a0e33df..a21757de26b50 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -337,6 +337,10 @@ void EagerUtils::HandleViewBetweenInputAndOutput( std::dynamic_pointer_cast(input_tensor.impl()); if (view_output_tensor->impl() == nullptr) { view_output_tensor->set_impl(std::make_shared()); + } else { + PADDLE_ENFORCE(view_output_tensor->is_dense_tensor(), + phi::errors::Unavailable( + "DenseTensor can not be inplaced with other Tensor.")); } auto view_output_dense_tensor = std::dynamic_pointer_cast(view_output_tensor->impl()); @@ -344,6 +348,35 @@ void EagerUtils::HandleViewBetweenInputAndOutput( view_output_dense_tensor->ShareInplaceVersionCounterWith( *input_dense_tensor); + VLOG(4) << "Perform View between Output Tensor(" + << view_output_tensor->name() << ") and Input Tensor(" + << input_tensor.name() + << "), share allocation and inplace version."; + } else if (input_tensor.is_dist_tensor()) { + auto input_dense_tensor = + std::dynamic_pointer_cast( + input_tensor.impl()) + ->unsafe_mutable_value(); + if (view_output_tensor->impl() == nullptr) { + view_output_tensor->set_impl( + std::make_shared( + input_tensor.dims(), + std::dynamic_pointer_cast( + input_tensor.impl()) + ->dist_attr())); + } else { + PADDLE_ENFORCE(view_output_tensor->is_dist_tensor(), + phi::errors::Unavailable( + "DistTensor can not be inplaced with other Tensor.")); + } + auto view_output_dense_tensor = + std::dynamic_pointer_cast( + view_output_tensor->impl()) + ->unsafe_mutable_value(); + view_output_dense_tensor->ShareBufferWith(*input_dense_tensor); + view_output_dense_tensor->ShareInplaceVersionCounterWith( + *input_dense_tensor); + VLOG(4) << "Perform View between Output Tensor(" << view_output_tensor->name() << ") and Input Tensor(" << input_tensor.name() diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 0ab19bb65d35b..63e59bbcfeede 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -62,10 +62,16 @@ typedef SSIZE_T ssize_t; #include "paddle/fluid/pybind/cuda_streams_py.h" #endif +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/core/flags.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/api/lib/api_gen_utils.h" +#include 
"paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif PHI_DECLARE_string(tensor_operants_mode); @@ -535,6 +541,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, const auto& attrs = paddle::OpMetaInfoHelper::GetAttrs(vec_map[0]); const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(vec_map[0]); const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[0]); + for (size_t i = 0; i < inputs.size(); ++i) { const auto& input = inputs.at(i); // Parse op_type first, so that use i + 1 @@ -552,17 +559,6 @@ static PyObject* eager_api_run_custom_op(PyObject* self, if (paddle::framework::detail::IsDuplicableVar(input)) { std::vector tensors = std::move(CastPyArg2VectorOfTensor(obj, i + 1)); // NOLINT - for (auto& tensor : tensors) { - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( - *(std::dynamic_pointer_cast( - tensor.impl())))))); - } - } ctx.EmplaceBackInputs(std::move(tensors)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add vector size = " @@ -570,19 +566,12 @@ static PyObject* eager_api_run_custom_op(PyObject* self, } else { paddle::Tensor tensor = std::move(CastPyArg2Tensor(obj, i + 1)); // NOLINT - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous(*( - std::dynamic_pointer_cast(tensor.impl())))))); - } ctx.EmplaceBackInput(std::move(tensor)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add Tensor for general case."; } } + // Parse op_type and inputs first, so that use 1 + inputs.size() + i int attr_start_idx = static_cast(1 + inputs.size()); for (size_t i = 0; i < attrs.size(); ++i) { @@ -628,6 +617,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, attr_type_str)); } } + { eager_gil_scoped_release guard; ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); @@ -671,11 +661,8 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ctx.EmplaceBackOutput(std::move(InitializedEmptyTensor())); } - // handle inplace map - ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); VLOG(7) << "Run Kernel of Custom Op: " << op_type; - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); - ctx.AssignInplaceOutputs(); + egr::run_custom_op_impl(vec_map[0], true, false, ctx); // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { @@ -684,7 +671,8 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ctx.MutableOutputAt(ctx.OutputRangeAt(i).first); if (!out_tensor->initialized()) { PADDLE_ENFORCE( - paddle::framework::detail::IsOptionalVar(outputs.at(i)), + paddle::framework::detail::IsOptionalVar(outputs.at(i)) || + out_tensor->is_dist_tensor(), phi::errors::InvalidArgument( "Custom operator's %d-th output is not initialized. " "Please check your implementation again. 
If you are " diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index c774cafcfd26a..484ea06944653 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -120,15 +120,15 @@ class PADDLE_API CustomOpKernelContext { std::vector InputsBetween(size_t start, size_t end) const; Tensor& MutableInputAt(size_t idx); std::vector* AllMutableInput(); - paddle::optional OptionalInputAt(size_t idx); + paddle::optional OptionalInputAt(size_t idx) const; paddle::optional> OptionalInputsBetween(size_t start, - size_t end); + size_t end) const; const std::vector& Attrs() const; - const std::vector>& InputRange(); - const std::vector>& OutputRange(); + const std::vector>& InputRange() const; + const std::vector>& OutputRange() const; Tensor* MutableOutputAt(size_t idx); std::vector MutableOutputBetween(size_t start, size_t end); - std::vector OutputsBetween(size_t start, size_t end); + std::vector OutputsBetween(size_t start, size_t end) const; std::vector* AllMutableOutput(); template @@ -151,8 +151,8 @@ class PADDLE_API CustomOpKernelContext { const std::unordered_map& inplace_map); void AssignInplaceOutputs(); std::vector* AllMutablePlainOutput(); - std::unordered_map GetInplaceIndexMap(); - std::unordered_map GetInplaceReverseIndexMap(); + std::unordered_map GetInplaceIndexMap() const; + std::unordered_map GetInplaceReverseIndexMap() const; private: // TODO(chenweihang): replaced be SmallVector diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index da8b9125a71dd..14334aa7c42a6 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/enforce.h" namespace paddle { @@ -63,10 +64,12 @@ PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { "happens when handling inplace optional inputs & outputs."; return; } - PADDLE_ENFORCE_EQ(src.is_dense_tensor() && dst->is_dense_tensor(), - true, - phi::errors::Unavailable( - "Now only supported DenseTensor in Custom Operator.")); + PADDLE_ENFORCE_EQ( + ((src.is_dense_tensor() && dst->is_dense_tensor()) || + (src.is_dist_tensor() && dst->is_dist_tensor())), + true, + phi::errors::Unavailable( + "Now only supported DenseTensor and DistTensor in Custom Operator.")); PADDLE_ENFORCE_EQ( src.initialized(), true, @@ -76,9 +79,19 @@ PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { true, phi::errors::Unavailable( "The Custom OpKernel origin output is not defined.")); - auto& dense_src = static_cast(*src.impl()); - auto* dense_dst = static_cast(dst->impl().get()); - *dense_dst = dense_src; + if (src.is_dense_tensor()) { + auto& dense_src = static_cast(*src.impl()); + auto* dense_dst = static_cast(dst->impl().get()); + *dense_dst = dense_src; + } else { + auto* dense_src = + static_cast(src.impl().get()) + ->unsafe_mutable_value(); + auto* dense_dst = + static_cast(dst->impl().get()) + ->unsafe_mutable_value(); + *dense_dst = *dense_src; + } } ////////////////////// Kernel Context ////////////////////// @@ -149,7 +162,8 @@ std::vector* CustomOpKernelContext::AllMutableInput() { return &inputs_; } -paddle::optional CustomOpKernelContext::OptionalInputAt(size_t idx) { +paddle::optional CustomOpKernelContext::OptionalInputAt( + size_t idx) const { if (!inputs_.at(idx).is_initialized()) { return paddle::none; } 
@@ -157,7 +171,7 @@ paddle::optional CustomOpKernelContext::OptionalInputAt(size_t idx) { } paddle::optional> -CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) { +CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) const { std::vector rlt; for (size_t i = start; i < end; ++i) { if (!inputs_.at(i).is_initialized()) { @@ -181,7 +195,7 @@ std::vector CustomOpKernelContext::MutableOutputBetween(size_t start, } std::vector CustomOpKernelContext::OutputsBetween(size_t start, - size_t end) { + size_t end) const { std::vector rlt; for (size_t i = start; i < end; ++i) { rlt.emplace_back(outputs_.at(i)); @@ -203,12 +217,12 @@ const std::pair& CustomOpKernelContext::OutputRangeAt( } const std::vector>& -CustomOpKernelContext::InputRange() { +CustomOpKernelContext::InputRange() const { return input_range_; } const std::vector>& -CustomOpKernelContext::OutputRange() { +CustomOpKernelContext::OutputRange() const { return output_range_; } @@ -293,12 +307,13 @@ std::vector* CustomOpKernelContext::AllMutablePlainOutput() { return &plain_outputs_; } -std::unordered_map CustomOpKernelContext::GetInplaceIndexMap() { +std::unordered_map CustomOpKernelContext::GetInplaceIndexMap() + const { return inplace_idx_map_; } std::unordered_map -CustomOpKernelContext::GetInplaceReverseIndexMap() { +CustomOpKernelContext::GetInplaceReverseIndexMap() const { return inplace_reverse_idx_map_; } ////////////////////// Op Meta Info ////////////////////// diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index b17e012584d7f..c5fbf8466f2bf 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -122,7 +122,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_semi_auto_parallel_single_strategy MODULES test_semi_auto_parallel_single_strategy) set_tests_properties(test_semi_auto_parallel_single_strategy - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300) py_test_modules(test_semi_auto_parallel_hybrid_strategy MODULES test_semi_auto_parallel_hybrid_strategy) set_tests_properties(test_semi_auto_parallel_hybrid_strategy diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py new file mode 100644 index 0000000000000..07496ec07e506 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from site import getsitepackages + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS, run_cmd + +# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. +# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find +# paddle include directory. 
Because the following path is generated after installing +# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. +paddle_includes = [] +for site_packages_path in getsitepackages(): + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include') + ) + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') + ) + +# Test for extra compile args +extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] +extra_nvcc_args = ['-O3'] + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = f'{get_build_directory()}\\dist_custom_relu\\dist_custom_relu.pyd' +if os.name == 'nt' and os.path.isfile(file): + cmd = f'del {file}' + run_cmd(cmd, True) + +if os.name == 'nt': + test_include = "..\\python\\paddle\\base\\tests\\auto_parallel" +else: + test_include = "../python/paddle/base/tests/auto_parallel" +paddle_includes.append(test_include) + +custom_ops = load( + name='dist_custom_relu_jit', + sources=[ + '../custom_op/custom_relu_op.cc', + '../custom_op/custom_relu_op_dup.cc', + '../custom_op/custom_relu_op.cu', + ], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True, +) + + +class TestCustomReluForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_body(self, x_shape, x_specs): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x_np = np.random.random(size=x_shape).astype(self._dtype) + x = paddle.to_tensor(x_np) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + y = paddle.add(x, x) + dist_y = paddle.add(dist_x, dist_x) + out = custom_ops.custom_relu(y) + dist_out = custom_ops.custom_relu(dist_y) + out.stop_gradient = False + dist_out.stop_gradient = False + + self.check_tensor_eq(out, dist_out) + + out.backward() + dist_out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + + def test_custom_relu(self): + self.test_body( + x_shape=[64, 32], + x_specs=['x', None], + ) + + def run_test_case(self): + paddle.set_device("gpu:" + str(dist.get_rank())) + self.test_custom_relu() + + +if __name__ == '__main__': + TestCustomReluForSemiAutoParallel().test_custom_relu() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py new file mode 100644 index 0000000000000..ef8ff6e004c45 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from site import getsitepackages + +from semi_auto_parallel_simple_net import TestSimpleNetForSemiAutoParallel + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS, run_cmd + +# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. +# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find +# paddle include directory. Because the following path is generated after installing +# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. +paddle_includes = [] +for site_packages_path in getsitepackages(): + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include') + ) + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') + ) + +# Test for extra compile args +extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] +extra_nvcc_args = ['-O3'] + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = f'{get_build_directory()}\\dist_custom_relu\\dist_custom_relu.pyd' +if os.name == 'nt' and os.path.isfile(file): + cmd = f'del {file}' + run_cmd(cmd, True) + +if os.name == 'nt': + test_include = "..\\python\\paddle\\base\\tests\\auto_parallel" +else: + test_include = "../python/paddle/base/tests/auto_parallel" +paddle_includes.append(test_include) + +custom_ops = load( + name='dist_custom_relu_jit', + sources=[ + '../custom_op/custom_relu_op.cc', + '../custom_op/custom_relu_op_dup.cc', + '../custom_op/custom_relu_op.cu', + ], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True, +) + +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +class PPDemoNet(nn.Layer): + def __init__(self, mesh0, mesh1, param_suffix=""): + super().__init__() + self.replicate_dist_attr0 = dist.DistAttr( + mesh=mesh0, sharding_specs=[None, None] + ) + self.replicate_dist_attr1 = dist.DistAttr( + mesh=mesh1, sharding_specs=[None, None] + ) + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="pp_demo_weight_0" + param_suffix, + initializer=paddle.nn.initializer.Uniform(0, 1), + ), + ), + dist_attr=self.replicate_dist_attr0, + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="pp_nemo_weight_1" + param_suffix, + initializer=paddle.nn.initializer.Uniform(0, 1), + ), + ), + dist_attr=self.replicate_dist_attr1, + ) + + def forward(self, x): + out = F.linear(x, self.w0) + out = custom_ops.custom_relu(out) + # out = F.relu(out) + out = dist.reshard(out, dist_attr=self.replicate_dist_attr1) + out = F.linear(out, self.w1) + return out + + +class 
TestSimpleNetWithCustomReluForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._pp_mesh0 = dist.ProcessMesh([0], dim_names=["x"]) + self._pp_mesh1 = dist.ProcessMesh([1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic_custom_relu(self, layer, shard_input=False): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) + + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + loss.backward() + + def test_demo_net(self): + mp_layer = dist.shard_layer( + PPDemoNet(self._pp_mesh0, self._pp_mesh1), + self._mesh, + self.shard_fn, + ) + self.run_dynamic_custom_relu(mp_layer) + + def run_test_case(self): + self.test_demo_net() + + +if __name__ == "__main__": + TestSimpleNetWithCustomReluForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index b1132a6a3a8dc..2589566cb670e 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -96,6 +96,16 @@ def test_add_n_api(self): user_defined_envs=envs, ) + def test_custom_relu_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_custom_relu.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index 27bff7eda64fa..d4d1418e831eb 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -112,6 +112,17 @@ def test_simple_net_zero_grads(self): user_defined_envs=envs, ) + def test_simple_net_custom_relu(self): + self._changeable_envs = {"backend": ["gpu"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_custom_relu.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main()
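
For context after the patch itself: the tests above JIT-compile the existing custom_relu extension and run it on both dense tensors and sharded DistTensors, which is exactly the path this change enables (run_custom_op_impl reshards inputs, infers empty outputs off-mesh, and wraps outputs back into DistTensor). Below is a minimal usage sketch distilled from semi_auto_parallel_for_custom_relu.py in this diff, not part of the patch: it assumes a two-rank GPU launch, and the relative source paths passed to load() are only valid inside Paddle's test tree, so adjust them for your own environment.

# Minimal sketch: calling a custom operator on sharded tensors (semi-auto parallel).
# Assumes two GPU ranks; paths and names mirror the test file added in this diff.
import numpy as np

import paddle
import paddle.distributed as dist
from paddle.utils.cpp_extension import load

# JIT-build the custom relu op, as the new tests do. Outside the Paddle test
# tree these source paths are placeholders and must point at your own op files.
custom_ops = load(
    name='dist_custom_relu_jit',
    sources=[
        '../custom_op/custom_relu_op.cc',
        '../custom_op/custom_relu_op_dup.cc',
        '../custom_op/custom_relu_op.cu',
    ],
    verbose=True,
)

paddle.set_device("gpu:" + str(dist.get_rank()))
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

x_np = np.random.random(size=[64, 32]).astype("float32")

# Dense baseline.
x = paddle.to_tensor(x_np)
x.stop_gradient = False
out = custom_ops.custom_relu(paddle.add(x, x))
out.stop_gradient = False

# Distributed run: shard the input along the mesh's "x" dimension. With this
# patch, the custom op's forward and backward kernels accept DistTensor inputs.
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])
dist_x = dist.shard_tensor(x_np, dist_attr=dist_attr)
dist_x.stop_gradient = False
dist_out = custom_ops.custom_relu(paddle.add(dist_x, dist_x))
dist_out.stop_gradient = False

# Gradients flow through RunCustomOpNode for both cases and should agree.
out.backward()
dist_out.backward()
np.testing.assert_allclose(out.numpy(), dist_out.numpy(), rtol=1e-05)
np.testing.assert_allclose(x.grad.numpy(), dist_x.grad.numpy(), rtol=1e-05)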