[BugFix] Fix bug for binary_cross_entropy_with_logits loss #54869

Merged
4 commits merged on Jun 28, 2023
Changes from 3 commits
5 changes: 3 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
@@ -1817,15 +1817,16 @@
     data_type : x
 
 - backward_op : sigmoid_cross_entropy_with_logits_grad
-  forward : sigmoid_cross_entropy_with_logits (Tensor x, Tensor label, bool normalize=false, int ignore_index=-100) -> Tensor(out)
-  args : (Tensor x, Tensor label, Tensor out_grad, bool normalize, int ignore_index)
+  forward : sigmoid_cross_entropy_with_logits (Tensor x, Tensor label, Tensor pos_weight, bool normalize=false, int ignore_index=-100) -> Tensor(out)
+  args : (Tensor x, Tensor label, Tensor pos_weight, Tensor out_grad, bool normalize, int ignore_index)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
   kernel :
     func : sigmoid_cross_entropy_with_logits_grad
   inplace : (out_grad -> x_grad)
+  optional : pos_weight
 
 - backward_op : sigmoid_double_grad
   forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/op_compat.yaml
@@ -2800,7 +2800,7 @@
 - op: sigmoid_cross_entropy_with_logits
   backward: sigmoid_cross_entropy_with_logits_grad
   inputs :
-    {x: X, label: Label}
+    {x: X, label: Label, pos_weight: PosWeight}
Contributor:
This one doesn't need to change; it exists only to stay compatible with the original naming style. A newly added parameter doesn't need to be compatible with the old style, and after the new IR refactor the original camel-case naming entries will all be removed.

Contributor Author:
done

   outputs :
     out : Out

5 changes: 3 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
@@ -2087,14 +2087,15 @@
   backward : sigmoid_grad
 
 - op : sigmoid_cross_entropy_with_logits
-  args : (Tensor x, Tensor label, bool normalize=false, int ignore_index=-100)
+  args : (Tensor x, Tensor label, Tensor pos_weight, bool normalize=false, int ignore_index=-100)
   output : Tensor
   infer_meta :
     func : SigmoidCrossEntropyWithLogitsInferMeta
   kernel :
     func : sigmoid_cross_entropy_with_logits
   inplace : (x -> out)
   backward : sigmoid_cross_entropy_with_logits_grad
+  optional : pos_weight
 
 - op : sign
   args : (Tensor x)
@@ -2495,7 +2496,7 @@
     func : WeightedSampleNeighborsInferMeta
   kernel :
     func : weighted_sample_neighbors
-  optional: eids
+  optional : eids
 
 - op : where
   args : (Tensor condition, Tensor x, Tensor y)
41 changes: 0 additions & 41 deletions paddle/phi/infermeta/binary.cc
@@ -2672,47 +2672,6 @@ void SegmentPoolInferMeta(const MetaTensor& x,
   }
 }
 
-void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
-                                            const MetaTensor& label,
-                                            bool normalize,
-                                            int ignore_index,
-                                            MetaTensor* out,
-                                            MetaConfig config) {
-  auto x_dims = x.dims();
-  auto labels_dims = label.dims();
-  int rank = x_dims.size();
-  PADDLE_ENFORCE_EQ(rank,
-                    labels_dims.size(),
-                    phi::errors::InvalidArgument(
-                        "Input(X) and Input(Label) shall have the same rank."
-                        "But received: the rank of Input(X) is [%d], "
-                        "the rank of Input(Label) is [%d].",
-                        rank,
-                        labels_dims.size()));
-
-  bool check = true;
-  if ((!config.is_runtime) &&
-      (phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
-    check = false;
-  }
-
-  if (check) {
-    PADDLE_ENFORCE_EQ(
-        phi::slice_ddim(x_dims, 0, rank),
-        phi::slice_ddim(labels_dims, 0, rank),
-        phi::errors::InvalidArgument(
-            "Input(X) and Input(Label) shall have the same shape "
-            "except the last dimension. But received: the shape of "
-            "Input(X) is [%s], the shape of Input(Label) is [%s].",
-            x_dims,
-            labels_dims));
-  }
-
-  out->set_dims(x_dims);
-  out->set_dtype(x.dtype());
-  out->share_lod(x);
-}
-
 void TakeAlongAxisInferMeta(const MetaTensor& x,
                             const MetaTensor& index,
                             int axis,
7 changes: 0 additions & 7 deletions paddle/phi/infermeta/binary.h
@@ -417,13 +417,6 @@ void SegmentPoolInferMeta(const MetaTensor& x,
                           MetaTensor* summed_ids,
                           MetaConfig config = MetaConfig());
 
-void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
-                                            const MetaTensor& label,
-                                            bool normalize,
-                                            int ignore_index,
-                                            MetaTensor* out,
-                                            MetaConfig config = MetaConfig());
-
 void TakeAlongAxisInferMeta(const MetaTensor& x,
                             const MetaTensor& index,
                             int axis,
56 changes: 56 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3410,5 +3410,61 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
   out_count->set_dims({-1});
   out_count->set_dtype(DataType::INT32);
 }
+
+void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
Contributor:
Please place the function in alphabetical order within the file.

Contributor Author:
done

+                                            const MetaTensor& label,
+                                            const MetaTensor& pos_weight,
+                                            bool normalize,
+                                            int ignore_index,
+                                            MetaTensor* out,
+                                            MetaConfig config) {
+  auto x_dims = x.dims();
+  auto labels_dims = label.dims();
+  int rank = x_dims.size();
+  PADDLE_ENFORCE_EQ(rank,
+                    labels_dims.size(),
+                    phi::errors::InvalidArgument(
+                        "Input(X) and Input(Label) shall have the same rank."
+                        "But received: the rank of Input(X) is [%d], "
+                        "the rank of Input(Label) is [%d].",
+                        rank,
+                        labels_dims.size()));
+
+  bool check = true;
+  if ((!config.is_runtime) &&
+      (phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
+    check = false;
+  }
+
+  if (check) {
+    PADDLE_ENFORCE_EQ(
+        phi::slice_ddim(x_dims, 0, rank),
+        phi::slice_ddim(labels_dims, 0, rank),
+        phi::errors::InvalidArgument(
+            "Input(X) and Input(Label) shall have the same shape "
+            "except the last dimension. But received: the shape of "
+            "Input(X) is [%s], the shape of Input(Label) is [%s].",
+            x_dims,
+            labels_dims));
+
+    if (pos_weight) {
+      auto weight_dims = pos_weight.dims();
+      PADDLE_ENFORCE_EQ(
+          phi::slice_ddim(weight_dims, 0, rank),
+          phi::slice_ddim(labels_dims, 0, rank),
+          phi::errors::InvalidArgument(
+              "Input(PosWeight) and Input(Label) shall have the same shape "
+              "But received: the shape of Input(PosWeight) is [%s], "
+              "the shape of Input(Label) is [%s].",
+              weight_dims,
+              labels_dims));
+    }
+  }
+
+  out->set_dims(x_dims);
+  out->set_dtype(x.dtype());
+  out->share_lod(x);
+}
+
 }  // namespace phi
 PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -646,4 +646,12 @@ void MoeInferMeta(const MetaTensor& x,
                   const std::string& act_type,
                   MetaTensor* out);
 
+void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
Contributor:
Same as above.

Contributor Author:
done

+                                            const MetaTensor& label,
+                                            const MetaTensor& pos_weight,
+                                            bool normalize,
+                                            int ignore_index,
+                                            MetaTensor* out,
+                                            MetaConfig config = MetaConfig());
+
 }  // namespace phi
@@ -20,28 +20,35 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
-                                             const DenseTensor& x,
-                                             const DenseTensor& label,
-                                             const DenseTensor& out_grad,
-                                             bool normalize,
-                                             int ignore_index,
-                                             DenseTensor* in_grad) {
+void SigmoidCrossEntropyWithLogitsGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& label,
+    const paddle::optional<DenseTensor>& pos_weight,
+    const DenseTensor& out_grad,
+    bool normalize,
+    int ignore_index,
+    DenseTensor* in_grad) {
   auto dx_data = dev_ctx.template Alloc<T>(in_grad);
 
   int limit = in_grad->numel();
   auto x_data = x.data<T>();
   auto label_data = label.data<T>();
   auto dout_data = out_grad.data<T>();
+  auto pos_weight_data =
+      (pos_weight.get_ptr() == nullptr ? nullptr
+                                       : pos_weight.get_ptr()->data<T>());
 
   for (int idx = 0; idx < limit; ++idx) {
     T x = x_data[idx];
     T label = label_data[idx];
     T dout = dout_data[idx];
+    T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
     if (static_cast<int>(label) == ignore_index) {
       dx_data[idx] = static_cast<T>(0.);
     } else {
       T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
-      T diff = simoid_x - label;
+      T diff = simoid_x * pos_weight_idx - label;
       dx_data[idx] = dout * diff;
     }
   }
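For reference, the element-wise gradient computed by the CPU grad kernel above can be written as follows. The notation is ours: $x_i$ is the logit, $y_i$ the label, $g_i$ the incoming out_grad, and $w_i$ the optional pos_weight, taken as 1 when pos_weight is not passed.

$$
\frac{\partial \ell_i}{\partial x_i} =
\begin{cases}
0, & y_i = \text{ignore\_index} \\
\left(\sigma(x_i)\,w_i - y_i\right) g_i, & \text{otherwise,}
\end{cases}
\qquad \sigma(x) = \frac{1}{1 + e^{-x}}.
$$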
21 changes: 14 additions & 7 deletions paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_kernel.cc
@@ -23,26 +23,33 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
-                                         const DenseTensor& x,
-                                         const DenseTensor& label,
-                                         bool normalize,
-                                         int ignore_index,
-                                         DenseTensor* out) {
+void SigmoidCrossEntropyWithLogitsKernel(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& label,
+    const paddle::optional<DenseTensor>& pos_weight,
+    bool normalize,
+    int ignore_index,
+    DenseTensor* out) {
   auto out_data = dev_ctx.template Alloc<T>(out);
   int limit = out->numel();
   auto x_data = x.data<T>();
   auto label_data = label.data<T>();
+  auto pos_weight_data =
+      (pos_weight.get_ptr() == nullptr ? nullptr
+                                       : pos_weight.get_ptr()->data<T>());
 
   for (int idx = 0; idx < limit; ++idx) {
     T x = x_data[idx];
     T label = label_data[idx];
     if (static_cast<int>(label) == ignore_index) {
       out_data[idx] = static_cast<T>(0.);
     } else {
+      T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
       T term1 = (x > 0) ? x : 0;
       T term2 = x * label;
       T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
-      out_data[idx] = term1 - term2 + term3;
+      out_data[idx] = term1 - term2 + term3 * pos_weight_idx;
     }
   }
 
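Likewise, the element-wise loss computed by the forward CPU kernel above is, in the same notation (a sketch of the implemented math, with $w_i = 1$ when pos_weight is omitted):

$$
\ell_i =
\begin{cases}
0, & y_i = \text{ignore\_index} \\
\max(x_i, 0) - x_i y_i + w_i \log\!\left(1 + e^{-\lvert x_i \rvert}\right), & \text{otherwise.}
\end{cases}
$$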
@@ -52,14 +52,50 @@ struct SigmoidBwdFunctor {
   }
 };
 
+template <typename T>
+struct SigmoidBwdPosWeightFunctor {
+  T ignore_index_;
+  T eps = static_cast<T>(1e-5);
+
+  HOSTDEVICE inline SigmoidBwdPosWeightFunctor(const T ignore_index)
+      : ignore_index_(ignore_index) {}
+
+  HOSTDEVICE inline phi::Array<T, 2> operator()(const T x,
+                                                const T label,
+                                                const T pos_weight,
+                                                const T dout) {
+    T counts;
+    T dx_data;
+
+    T diff = label - static_cast<T>(ignore_index_);
+    if ((diff > -eps) && (diff < eps)) {
+      dx_data = static_cast<T>(0.);
+      counts = 0;
+    } else {
+      T simoid_x =
+          static_cast<T>(1) / (static_cast<T>(1) + phi::funcs::real_exp(-x));
+      T diff = simoid_x * pos_weight - label;
+      dx_data = dout * diff;
+      counts = 1;
+    }
+    phi::Array<T, 2> outs;
+
+    outs[0] = dx_data;
+    outs[1] = counts;
+    return outs;
+  }
+};
+
 template <typename T, typename Context>
-void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
-                                             const DenseTensor &x,
-                                             const DenseTensor &label,
-                                             const DenseTensor &out_grad,
-                                             bool normalize,
-                                             int ignore_index,
-                                             DenseTensor *in_grad) {
+void SigmoidCrossEntropyWithLogitsGradKernel(
+    const Context &dev_ctx,
+    const DenseTensor &x,
+    const DenseTensor &label,
+    const paddle::optional<DenseTensor> &pos_weight,
+    const DenseTensor &out_grad,
+    bool normalize,
+    int ignore_index,
+    DenseTensor *in_grad) {
   auto dx_data = dev_ctx.template Alloc<T>(in_grad);
 
   // Temporary memory
@@ -70,11 +106,19 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
   dev_ctx.template Alloc<T>(counts_tensor);
   counts_tensor->Resize(in_grad->dims());
 
-  std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
   std::vector<DenseTensor *> outs = {in_grad, counts_tensor};
-  auto functor = SigmoidBwdFunctor<T>(ignore_index);
-  phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
-      dev_ctx, ins, &outs, functor);
+  if (pos_weight.get_ptr() == nullptr) {
+    std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
+    auto functor = SigmoidBwdFunctor<T>(ignore_index);
+    phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
+        dev_ctx, ins, &outs, functor);
+  } else {
+    std::vector<const DenseTensor *> ins = {
+        &x, &label, pos_weight.get_ptr(), &out_grad};
+    auto functor = SigmoidBwdPosWeightFunctor<T>(ignore_index);
+    phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
+        dev_ctx, ins, &outs, functor);
+  }
   if (normalize) {
     DenseTensor *norm_tensor = new DenseTensor();
     norm_tensor->Resize({sizeof(T)});
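To make the per-element math easy to verify in isolation, below is a minimal standalone C++ sketch that mirrors the forward and backward formulas used by the kernels in this PR. It is illustrative only: the plain functions, the std::optional usage, and the hard-coded sample values are assumptions made for the example, not Paddle code.

// Standalone sketch of the per-element binary cross entropy with logits,
// with an optional positive-class weight (illustrative, not Paddle code).
#include <cmath>
#include <cstdio>
#include <optional>
#include <vector>

// Forward: max(x, 0) - x * label + pos_weight * log(1 + exp(-|x|)).
double bce_with_logits(double x, double label, double pos_weight) {
  double term1 = x > 0 ? x : 0;
  double term2 = x * label;
  double term3 = std::log(1.0 + std::exp(-std::abs(x)));
  return term1 - term2 + term3 * pos_weight;
}

// Backward: dL/dx = (sigmoid(x) * pos_weight - label) * dout.
double bce_with_logits_grad(double x, double label, double pos_weight,
                            double dout) {
  double sigmoid_x = 1.0 / (1.0 + std::exp(-x));
  return (sigmoid_x * pos_weight - label) * dout;
}

int main() {
  std::vector<double> x = {1.5, -0.3, 0.7};
  std::vector<double> label = {1.0, 0.0, 1.0};
  std::optional<std::vector<double>> pos_weight =
      std::vector<double>{2.0, 2.0, 2.0};

  for (size_t i = 0; i < x.size(); ++i) {
    // Default weight of 1 when pos_weight is absent, matching the kernels.
    double w = pos_weight ? (*pos_weight)[i] : 1.0;
    double loss = bce_with_logits(x[i], label[i], w);
    double grad = bce_with_logits_grad(x[i], label[i], w, /*dout=*/1.0);
    std::printf("loss=%.6f grad=%.6f\n", loss, grad);
  }
  return 0;
}

When pos_weight is absent the weight falls back to 1, which matches the nullptr checks in the CPU kernels and the dispatch between SigmoidBwdFunctor and SigmoidBwdPosWeightFunctor on the GPU side.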