Skip to content

Commit

Permalink
part-2 fix PR (PaddlePaddle#54389)
Browse files Browse the repository at this point in the history
  • Loading branch information
wentaoyu committed Nov 28, 2023
1 parent ce3559d commit 7f11a73
Showing 1 changed file with 0 additions and 108 deletions.
108 changes: 0 additions & 108 deletions paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,114 +312,6 @@ class MultiplyDoubleGradNode : public egr::GradNodeBase {
int axis_ = -1;
};

class MultiplyGradNode : public egr::GradNodeBase {
public:
MultiplyGradNode() : egr::GradNodeBase() {}
MultiplyGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
: egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
~MultiplyGradNode() override = default;

virtual paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
operator()(paddle::small_vector<std::vector<paddle::Tensor>, // NOLINT
egr::kSlotSmallVectorSize>& grads, // NOLINT
bool create_graph = false,
bool is_new_grad = false) override;
std::string name() override { return "MultiplyGradNode"; }

void ClearTensorWrappers() override {
x_.clear();
y_.clear();

SetIsTensorWrappersCleared(true);
}

std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node =
std::shared_ptr<MultiplyGradNode>(new MultiplyGradNode(*this));
return copied_node;
}

// SetTensorWrapperX, SetTensorWrapperY, ...
void SetTensorWrapperx(const paddle::Tensor& x) {
x_ = egr::TensorWrapper(x, false);
}
void SetTensorWrappery(const paddle::Tensor& y) {
y_ = egr::TensorWrapper(y, false);
}

void SetTensorWrapperNoNeedBufferx(const paddle::Tensor& x) {
x_ = egr::TensorWrapper(x, true);
}
void SetTensorWrapperNoNeedBuffery(const paddle::Tensor& y) {
y_ = egr::TensorWrapper(y, true);
}

// SetAttributes
void SetAttributeaxis(const int& axis) { axis_ = axis; }

private:
// TensorWrappers
egr::TensorWrapper x_;
egr::TensorWrapper y_;

// Attributes
int axis_ = -1;
};

class MultiplyDoubleGradNode : public egr::GradNodeBase {
public:
MultiplyDoubleGradNode() : egr::GradNodeBase() {}
MultiplyDoubleGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
: egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
~MultiplyDoubleGradNode() override = default;

virtual paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
operator()(paddle::small_vector<std::vector<paddle::Tensor>, // NOLINT
egr::kSlotSmallVectorSize>& grads, // NOLINT
bool create_graph = false,
bool is_new_grad = false) override;
std::string name() override { return "MultiplyDoubleGradNode"; }

void ClearTensorWrappers() override {
x_.clear();
y_.clear();
grad_out_.clear();

SetIsTensorWrappersCleared(true);
}

std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node = std::shared_ptr<MultiplyDoubleGradNode>(
new MultiplyDoubleGradNode(*this));
return copied_node;
}

// SetTensorWrapperX, SetTensorWrapperY, ...
void SetTensorWrapperx(const paddle::Tensor& x) {
x_ = egr::TensorWrapper(x, false);
}
void SetTensorWrappery(const paddle::Tensor& y) {
y_ = egr::TensorWrapper(y, false);
}
void SetTensorWrappergrad_out(const paddle::Tensor& grad_out) {
grad_out_ = egr::TensorWrapper(grad_out, false);
}

// SetAttributes
void SetAttributeaxis(const int& axis) { axis_ = axis; }

private:
// TensorWrappers
egr::TensorWrapper x_;
egr::TensorWrapper y_;
egr::TensorWrapper grad_out_;

// Attributes
int axis_ = -1;
};

class MultiplyTripleGradNode : public egr::GradNodeBase {
public:
MultiplyTripleGradNode() : egr::GradNodeBase() {}
Expand Down

0 comments on commit 7f11a73

Please sign in to comment.