force sync batch norm grad sequential (#52268)
wanghuancoder authored Mar 30, 2023
1 parent 551ff88 commit 336160c
Showing 9 changed files with 1,435 additions and 4 deletions.
37 changes: 37 additions & 0 deletions paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h
@@ -26,3 +26,40 @@ paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,
                              std::vector<int> dilations,
                              int groups,
                              std::string data_format);

std::tuple<paddle::Tensor,
           paddle::Tensor&,
           paddle::Tensor&,
           paddle::Tensor,
           paddle::Tensor,
           paddle::Tensor>
sync_batch_norm__ad_func(const paddle::Tensor& x,
                         paddle::Tensor& mean,      // NOLINT
                         paddle::Tensor& variance,  // NOLINT
                         const paddle::Tensor& scale,
                         const paddle::Tensor& bias,
                         bool is_test,
                         float momentum,
                         float epsilon,
                         std::string data_layout,
                         bool use_global_stats,
                         bool trainable_statistics);

namespace sparse {
std::tuple<paddle::Tensor,
           paddle::Tensor&,
           paddle::Tensor&,
           paddle::Tensor,
           paddle::Tensor,
           paddle::Tensor>
sync_batch_norm__ad_func(const paddle::Tensor& x,
                         paddle::Tensor& mean,      // NOLINT
                         paddle::Tensor& variance,  // NOLINT
                         const paddle::Tensor& scale,
                         const paddle::Tensor& bias,
                         bool is_test,
                         float momentum,
                         float epsilon,
                         std::string data_layout,
                         bool use_global_stats,
                         bool trainable_statistics);
}  // namespace sparse
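For context, a minimal sketch of a call site for the declaration above. Only the signature comes from this header; the wrapper name `RunSyncBatchNorm`, the attribute values, and the tuple-element meanings are illustrative assumptions, not part of this commit.

```cpp
#include <tuple>

#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"

// Sketch only. The trailing underscore in sync_batch_norm__ad_func follows
// Paddle's inplace naming: mean and variance are taken by mutable reference
// (the running statistics are updated in place) and are also handed back as
// references inside the result tuple.
paddle::Tensor RunSyncBatchNorm(const paddle::Tensor& x,
                                paddle::Tensor& mean,      // NOLINT
                                paddle::Tensor& variance,  // NOLINT
                                const paddle::Tensor& scale,
                                const paddle::Tensor& bias) {
  auto result = sync_batch_norm__ad_func(x,
                                         mean,
                                         variance,
                                         scale,
                                         bias,
                                         /*is_test=*/false,
                                         /*momentum=*/0.9f,
                                         /*epsilon=*/1e-5f,
                                         /*data_layout=*/"NCHW",
                                         /*use_global_stats=*/false,
                                         /*trainable_statistics=*/false);
  // Assumed tuple layout, by analogy with batch_norm: out, mean_out,
  // variance_out, saved_mean, saved_variance, reserve_space.
  return std::get<0>(result);
}
```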
1 change: 1 addition & 0 deletions paddle/fluid/eager/api/manual/eager_manual/forwards/CMakeLists.txt
@@ -1,4 +1,5 @@
set(eager_manual_functions
    ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc
    ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/forwards/conv2d_fwd_function.cc
    ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/forwards/sync_batch_norm_fwd_func.cc
    PARENT_SCOPE)

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt
@@ -1,4 +1,5 @@
set(eager_manual_nodes
    ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/conv2d_nodes.cc
    ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc
    ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc
    PARENT_SCOPE)
171 changes: 171 additions & 0 deletions paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h
@@ -204,3 +204,174 @@ class AddNGradNodeFinal : public egr::GradNodeBase {

  // Attributes
};

class SyncBatchNormGradNode : public egr::GradNodeBase {
 public:
  SyncBatchNormGradNode() : egr::GradNodeBase() {}
  SyncBatchNormGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
      : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
  ~SyncBatchNormGradNode() override = default;

  virtual paddle::small_vector<std::vector<paddle::Tensor>,
                               egr::kSlotSmallVectorSize>
  operator()(paddle::small_vector<std::vector<paddle::Tensor>,
                                  egr::kSlotSmallVectorSize>& grads,  // NOLINT
             bool create_graph = false,
             bool is_new_grad = false) override;
  std::string name() override { return "SyncBatchNormGradNode"; }

  void ClearTensorWrappers() override {
    x_.clear();
    scale_.clear();
    bias_.clear();
    saved_mean_.clear();
    saved_variance_.clear();
    reserve_space_.clear();

    SetIsTensorWrappersCleared(true);
  }

  std::shared_ptr<GradNodeBase> Copy() const override {
    auto copied_node = std::shared_ptr<SyncBatchNormGradNode>(
        new SyncBatchNormGradNode(*this));
    return copied_node;
  }

  // SetTensorWrapperX, SetTensorWrapperY, ...
  void SetTensorWrapperx(const paddle::Tensor& x) {
    x_ = egr::TensorWrapper(x, false);
  }
  void SetTensorWrapperscale(const paddle::Tensor& scale) {
    scale_ = egr::TensorWrapper(scale, false);
  }
  void SetTensorWrapperbias(const paddle::Tensor& bias) {
    bias_ = egr::TensorWrapper(bias, false);
  }
  void SetTensorWrappersaved_mean(const paddle::Tensor& saved_mean) {
    saved_mean_ = egr::TensorWrapper(saved_mean, false);
  }
  void SetTensorWrappersaved_variance(const paddle::Tensor& saved_variance) {
    saved_variance_ = egr::TensorWrapper(saved_variance, false);
  }
  void SetTensorWrapperreserve_space(const paddle::Tensor& reserve_space) {
    reserve_space_ = egr::TensorWrapper(reserve_space, false);
  }

  // SetAttributes
  void SetAttributemomentum(const float& momentum) { momentum_ = momentum; }
  void SetAttributeepsilon(const float& epsilon) { epsilon_ = epsilon; }
  void SetAttributedata_layout(const std::string& data_layout) {
    data_layout_ = data_layout;
  }
  void SetAttributeis_test(const bool& is_test) { is_test_ = is_test; }
  void SetAttributeuse_global_stats(const bool& use_global_stats) {
    use_global_stats_ = use_global_stats;
  }
  void SetAttributetrainable_statistics(const bool& trainable_statistics) {
    trainable_statistics_ = trainable_statistics;
  }

 private:
  // TensorWrappers
  egr::TensorWrapper x_;
  egr::TensorWrapper scale_;
  egr::TensorWrapper bias_;
  egr::TensorWrapper saved_mean_;
  egr::TensorWrapper saved_variance_;
  egr::TensorWrapper reserve_space_;

  // Attributes
  float momentum_;
  float epsilon_;
  std::string data_layout_;
  bool is_test_;
  bool use_global_stats_;
  bool trainable_statistics_;
};
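To make the setter naming above concrete, here is a hedged sketch of how a manual forward function typically populates such a node. The helper name, the (1, 5) slot counts, and the omission of the surrounding autograd bookkeeping are assumptions; the real wiring lives in the sync_batch_norm_fwd_func.cc added by this commit.

```cpp
#include <memory>
#include <string>

#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"

// Sketch only: capture forward tensors and attributes on the grad node so the
// backward kernel can replay them. Slot counts (1, 5) are illustrative.
std::shared_ptr<SyncBatchNormGradNode> MakeSyncBnGradNode(
    const paddle::Tensor& x,
    const paddle::Tensor& scale,
    const paddle::Tensor& bias,
    const paddle::Tensor& saved_mean,
    const paddle::Tensor& saved_variance,
    const paddle::Tensor& reserve_space,
    float momentum,
    float epsilon,
    const std::string& data_layout,
    bool is_test,
    bool use_global_stats,
    bool trainable_statistics) {
  auto grad_node =
      std::shared_ptr<SyncBatchNormGradNode>(new SyncBatchNormGradNode(1, 5));

  // Save the forward tensors the backward kernel will need.
  grad_node->SetTensorWrapperx(x);
  grad_node->SetTensorWrapperscale(scale);
  grad_node->SetTensorWrapperbias(bias);
  grad_node->SetTensorWrappersaved_mean(saved_mean);
  grad_node->SetTensorWrappersaved_variance(saved_variance);
  grad_node->SetTensorWrapperreserve_space(reserve_space);

  // Record the attributes verbatim from the forward call.
  grad_node->SetAttributemomentum(momentum);
  grad_node->SetAttributeepsilon(epsilon);
  grad_node->SetAttributedata_layout(data_layout);
  grad_node->SetAttributeis_test(is_test);
  grad_node->SetAttributeuse_global_stats(use_global_stats);
  grad_node->SetAttributetrainable_statistics(trainable_statistics);
  return grad_node;
}
```

The sparse:: variant that follows declares the identical interface, presumably for the sparse-tensor code path.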

namespace sparse {
class SyncBatchNormGradNode : public egr::GradNodeBase {
 public:
  SyncBatchNormGradNode() : egr::GradNodeBase() {}
  SyncBatchNormGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
      : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
  ~SyncBatchNormGradNode() override = default;

  virtual paddle::small_vector<std::vector<paddle::Tensor>,
                               egr::kSlotSmallVectorSize>
  operator()(paddle::small_vector<std::vector<paddle::Tensor>,
                                  egr::kSlotSmallVectorSize>& grads,  // NOLINT
             bool create_graph = false,
             bool is_new_grad = false) override;
  std::string name() override { return "SyncBatchNormGradNode"; }

  void ClearTensorWrappers() override {
    x_.clear();
    scale_.clear();
    bias_.clear();
    saved_mean_.clear();
    saved_variance_.clear();
    reserve_space_.clear();

    SetIsTensorWrappersCleared(true);
  }

  std::shared_ptr<GradNodeBase> Copy() const override {
    auto copied_node = std::shared_ptr<SyncBatchNormGradNode>(
        new SyncBatchNormGradNode(*this));
    return copied_node;
  }

  // SetTensorWrapperX, SetTensorWrapperY, ...
  void SetTensorWrapperx(const paddle::Tensor& x) {
    x_ = egr::TensorWrapper(x, false);
  }
  void SetTensorWrapperscale(const paddle::Tensor& scale) {
    scale_ = egr::TensorWrapper(scale, false);
  }
  void SetTensorWrapperbias(const paddle::Tensor& bias) {
    bias_ = egr::TensorWrapper(bias, false);
  }
  void SetTensorWrappersaved_mean(const paddle::Tensor& saved_mean) {
    saved_mean_ = egr::TensorWrapper(saved_mean, false);
  }
  void SetTensorWrappersaved_variance(const paddle::Tensor& saved_variance) {
    saved_variance_ = egr::TensorWrapper(saved_variance, false);
  }
  void SetTensorWrapperreserve_space(const paddle::Tensor& reserve_space) {
    reserve_space_ = egr::TensorWrapper(reserve_space, false);
  }

  // SetAttributes
  void SetAttributemomentum(const float& momentum) { momentum_ = momentum; }
  void SetAttributeepsilon(const float& epsilon) { epsilon_ = epsilon; }
  void SetAttributedata_layout(const std::string& data_layout) {
    data_layout_ = data_layout;
  }
  void SetAttributeis_test(const bool& is_test) { is_test_ = is_test; }
  void SetAttributeuse_global_stats(const bool& use_global_stats) {
    use_global_stats_ = use_global_stats;
  }
  void SetAttributetrainable_statistics(const bool& trainable_statistics) {
    trainable_statistics_ = trainable_statistics;
  }

 private:
  // TensorWrappers
  egr::TensorWrapper x_;
  egr::TensorWrapper scale_;
  egr::TensorWrapper bias_;
  egr::TensorWrapper saved_mean_;
  egr::TensorWrapper saved_variance_;
  egr::TensorWrapper reserve_space_;

  // Attributes
  float momentum_;
  float epsilon_;
  std::string data_layout_;
  bool is_test_;
  bool use_global_stats_;
  bool trainable_statistics_;
};

} // namespace sparse
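Finally, a sketch of how the eager backward engine conceptually invokes such a node through the `operator()` declared above. In practice the engine, not user code, drives this call; the single-slot gradient layout and the helper names here are assumptions.

```cpp
#include <memory>
#include <vector>

#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"

// Sketch only: feed the output gradient into slot 0 and receive per-slot
// gradients for the forward inputs (x, scale, bias, ...) back.
paddle::small_vector<std::vector<paddle::Tensor>, egr::kSlotSmallVectorSize>
CallSyncBnGrad(const std::shared_ptr<SyncBatchNormGradNode>& grad_node,
               const paddle::Tensor& grad_of_out) {  // hypothetical dL/d(out)
  paddle::small_vector<std::vector<paddle::Tensor>, egr::kSlotSmallVectorSize>
      grads;
  grads.resize(1);
  grads[0].push_back(grad_of_out);  // slot 0: gradient of the forward output

  return (*grad_node)(grads, /*create_graph=*/false, /*is_new_grad=*/false);
}
```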
