[BugFix] Fix bug for binary_cross_entropy_with_logits loss #54869

Merged
Merged 4 commits on Jun 28, 2023
5 changes: 3 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
@@ -1817,15 +1817,16 @@
data_type : x

- backward_op : sigmoid_cross_entropy_with_logits_grad
forward : sigmoid_cross_entropy_with_logits (Tensor x, Tensor label, bool normalize=false, int ignore_index=-100) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, bool normalize, int ignore_index)
forward : sigmoid_cross_entropy_with_logits (Tensor x, Tensor label, Tensor pos_weight, bool normalize=false, int ignore_index=-100) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor pos_weight, Tensor out_grad, bool normalize, int ignore_index)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : sigmoid_cross_entropy_with_logits_grad
inplace : (out_grad -> x_grad)
optional : pos_weight

- backward_op : sigmoid_double_grad
forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
5 changes: 3 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
@@ -2087,14 +2087,15 @@
backward : sigmoid_grad

- op : sigmoid_cross_entropy_with_logits
args : (Tensor x, Tensor label, bool normalize=false, int ignore_index=-100)
args : (Tensor x, Tensor label, Tensor pos_weight, bool normalize=false, int ignore_index=-100)
output : Tensor
infer_meta :
func : SigmoidCrossEntropyWithLogitsInferMeta
kernel :
func : sigmoid_cross_entropy_with_logits
inplace : (x -> out)
backward : sigmoid_cross_entropy_with_logits_grad
optional : pos_weight

- op : sign
args : (Tensor x)
@@ -2495,7 +2496,7 @@
func : WeightedSampleNeighborsInferMeta
kernel :
func : weighted_sample_neighbors
optional: eids
optional : eids

- op : where
args : (Tensor condition, Tensor x, Tensor y)
41 changes: 0 additions & 41 deletions paddle/phi/infermeta/binary.cc
@@ -2672,47 +2672,6 @@ void SegmentPoolInferMeta(const MetaTensor& x,
}
}

void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config) {
auto x_dims = x.dims();
auto labels_dims = label.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank,
labels_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank,
labels_dims.size()));

bool check = true;
if ((!config.is_runtime) &&
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
check = false;
}

if (check) {
PADDLE_ENFORCE_EQ(
phi::slice_ddim(x_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension. But received: the shape of "
"Input(X) is [%s], the shape of Input(Label) is [%s].",
x_dims,
labels_dims));
}

out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}

void TakeAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
7 changes: 0 additions & 7 deletions paddle/phi/infermeta/binary.h
@@ -417,13 +417,6 @@ void SegmentPoolInferMeta(const MetaTensor& x,
MetaTensor* summed_ids,
MetaConfig config = MetaConfig());

void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config = MetaConfig());

void TakeAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
56 changes: 56 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -2771,6 +2771,61 @@ void SgdInferMeta(const MetaTensor& param,
}
}

void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& pos_weight,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config) {
auto x_dims = x.dims();
auto labels_dims = label.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank,
labels_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank,
labels_dims.size()));

bool check = true;
if ((!config.is_runtime) &&
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
check = false;
}

if (check) {
PADDLE_ENFORCE_EQ(
phi::slice_ddim(x_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension. But received: the shape of "
"Input(X) is [%s], the shape of Input(Label) is [%s].",
x_dims,
labels_dims));

if (pos_weight) {
auto weight_dims = pos_weight.dims();
PADDLE_ENFORCE_EQ(
phi::slice_ddim(weight_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
phi::errors::InvalidArgument(
"Input(pos_weight) and Input(Label) shall have the same shape "
"But received: the shape of Input(PosWeight) is [%s], "
"the shape of Input(Label) is [%s].",
weight_dims,
labels_dims));
}
}

out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}

void SendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
@@ -3410,5 +3465,6 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}

} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
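For reference, the shape checks performed by the relocated SigmoidCrossEntropyWithLogitsInferMeta (summarized from the function above; the pos_weight check only runs when that optional input is supplied, and all checks are skipped at compile time when any dimension is still unknown) amount to:

$$\mathrm{rank}(x)=\mathrm{rank}(label),\qquad \mathrm{shape}(x)=\mathrm{shape}(label),\qquad \mathrm{shape}(pos\_weight)=\mathrm{shape}(label)$$

The output then takes x's dims, dtype, and LoD.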
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -523,6 +523,14 @@ void SgdInferMeta(const MetaTensor& param,
MetaTensor* param_out,
MetaTensor* master_param_out);

void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& pos_weight,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config = MetaConfig());

void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis,
MetaTensor* out,
@@ -20,28 +20,35 @@
namespace phi {

template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const DenseTensor& out_grad,
bool normalize,
int ignore_index,
DenseTensor* in_grad) {
void SigmoidCrossEntropyWithLogitsGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
const DenseTensor& out_grad,
bool normalize,
int ignore_index,
DenseTensor* in_grad) {
auto dx_data = dev_ctx.template Alloc<T>(in_grad);

int limit = in_grad->numel();
auto x_data = x.data<T>();
auto label_data = label.data<T>();
auto dout_data = out_grad.data<T>();
auto pos_weight_data =
(pos_weight.get_ptr() == nullptr ? nullptr
: pos_weight.get_ptr()->data<T>());

for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
T dout = dout_data[idx];
T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
if (static_cast<int>(label) == ignore_index) {
dx_data[idx] = static_cast<T>(0.);
} else {
T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
T diff = simoid_x - label;
T diff = simoid_x * pos_weight_idx - label;
dx_data[idx] = dout * diff;
}
}
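Stated as a formula (a summary of the loop above, not additional behavior), the CPU backward kernel now computes, per element, with $\sigma(x)=1/(1+e^{-x})$ and $w_i$ read from pos_weight (taken as 1 when pos_weight is not supplied):

$$\frac{\partial L}{\partial x_i}= dout_i\,\bigl(\sigma(x_i)\,w_i-label_i\bigr)$$

with the gradient set to 0 wherever $label_i$ equals ignore_index.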
21 changes: 14 additions & 7 deletions paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_kernel.cc
@@ -23,26 +23,33 @@
namespace phi {

template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
bool normalize,
int ignore_index,
DenseTensor* out) {
void SigmoidCrossEntropyWithLogitsKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
bool normalize,
int ignore_index,
DenseTensor* out) {
auto out_data = dev_ctx.template Alloc<T>(out);
int limit = out->numel();
auto x_data = x.data<T>();
auto label_data = label.data<T>();
auto pos_weight_data =
(pos_weight.get_ptr() == nullptr ? nullptr
: pos_weight.get_ptr()->data<T>());

for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
if (static_cast<int>(label) == ignore_index) {
out_data[idx] = static_cast<T>(0.);
} else {
T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
out_data[idx] = term1 - term2 + term3;
out_data[idx] = term1 - term2 + term3 * pos_weight_idx;
}
}

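Per element, the forward CPU kernel evaluates the numerically stable form of the weighted loss (again writing $w_i$ for the pos_weight entry, or 1 when the input is absent), and writes 0 wherever $label_i$ equals ignore_index:

$$out_i=\max(x_i,0)-x_i\,label_i+w_i\,\log\bigl(1+e^{-|x_i|}\bigr)$$

Below is a minimal standalone sketch of the same element-wise computation, outside Paddle's tensor types; the function name and the use of a null pointer to model the absent optional input are assumptions of this sketch, not part of the PR.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Element-wise binary cross entropy with logits, mirroring the CPU kernel:
// out = max(x, 0) - x * label + w * log(1 + exp(-|x|)),
// with w = 1 when no pos_weight is supplied (modeled here by a null pointer).
double bce_with_logits_elem(double x, double label, const double* pos_weight) {
  double w = (pos_weight == nullptr) ? 1.0 : *pos_weight;
  double term1 = std::max(x, 0.0);
  double term2 = x * label;
  double term3 = std::log(1.0 + std::exp(-std::abs(x)));
  return term1 - term2 + term3 * w;
}

int main() {
  double w = 2.0;
  std::printf("weighted:   %f\n", bce_with_logits_elem(0.5, 1.0, &w));
  std::printf("unweighted: %f\n", bce_with_logits_elem(0.5, 1.0, nullptr));
  return 0;
}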
@@ -52,14 +52,50 @@ struct SigmoidBwdFunctor {
}
};

template <typename T>
struct SigmoidBwdPosWeightFunctor {
T ignore_index_;
T eps = static_cast<T>(1e-5);

HOSTDEVICE inline SigmoidBwdPosWeightFunctor(const T ignore_index)
: ignore_index_(ignore_index) {}

HOSTDEVICE inline phi::Array<T, 2> operator()(const T x,
const T label,
const T pos_weight,
const T dout) {
T counts;
T dx_data;

T diff = label - static_cast<T>(ignore_index_);
if ((diff > -eps) && (diff < eps)) {
dx_data = static_cast<T>(0.);
counts = 0;
} else {
T simoid_x =
static_cast<T>(1) / (static_cast<T>(1) + phi::funcs::real_exp(-x));
T diff = simoid_x * pos_weight - label;
dx_data = dout * diff;
counts = 1;
}
phi::Array<T, 2> outs;

outs[0] = dx_data;
outs[1] = counts;
return outs;
}
};

template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &label,
const DenseTensor &out_grad,
bool normalize,
int ignore_index,
DenseTensor *in_grad) {
void SigmoidCrossEntropyWithLogitsGradKernel(
const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &label,
const paddle::optional<DenseTensor> &pos_weight,
const DenseTensor &out_grad,
bool normalize,
int ignore_index,
DenseTensor *in_grad) {
auto dx_data = dev_ctx.template Alloc<T>(in_grad);

// Temporary memory
@@ -70,11 +106,19 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
dev_ctx.template Alloc<T>(counts_tensor);
counts_tensor->Resize(in_grad->dims());

std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
std::vector<DenseTensor *> outs = {in_grad, counts_tensor};
auto functor = SigmoidBwdFunctor<T>(ignore_index);
phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
dev_ctx, ins, &outs, functor);
if (pos_weight.get_ptr() == nullptr) {
std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
auto functor = SigmoidBwdFunctor<T>(ignore_index);
phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
dev_ctx, ins, &outs, functor);
} else {
std::vector<const DenseTensor *> ins = {
&x, &label, pos_weight.get_ptr(), &out_grad};
auto functor = SigmoidBwdPosWeightFunctor<T>(ignore_index);
phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
dev_ctx, ins, &outs, functor);
}
if (normalize) {
DenseTensor *norm_tensor = new DenseTensor();
norm_tensor->Resize({sizeof(T)});
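For completeness, a host-side standalone sketch of the per-element math in SigmoidBwdPosWeightFunctor above (illustrative only; the function name is an assumption, and the HOSTDEVICE/ElementwiseKernel plumbing is omitted). It keeps the float-tolerant comparison against ignore_index that the functor uses; the second returned value mirrors the counts output the kernel accumulates for the normalize path.

#include <cmath>
#include <cstdio>
#include <utility>

// Element-wise backward pass mirroring SigmoidBwdPosWeightFunctor:
// dx = dout * (sigmoid(x) * pos_weight - label), and dx = 0 (count 0) when
// label equals ignore_index within a small tolerance eps.
std::pair<double, double> sigmoid_bce_grad_elem(double x, double label,
                                                double pos_weight, double dout,
                                                double ignore_index,
                                                double eps = 1e-5) {
  if (std::abs(label - ignore_index) < eps) {
    return {0.0, 0.0};  // ignored element: zero gradient, not counted
  }
  double sigmoid_x = 1.0 / (1.0 + std::exp(-x));
  double dx = dout * (sigmoid_x * pos_weight - label);
  return {dx, 1.0};  // gradient and count = 1
}

int main() {
  auto [dx, count] = sigmoid_bce_grad_elem(0.5, 1.0, 2.0, 1.0, -100.0);
  std::printf("dx = %f, count = %f\n", dx, count);
  return 0;
}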