Skip to content

Commit

Permalink
Fix bug of c_softmax_with_cross_entropy_op_xpu_op (#52296)
Browse files Browse the repository at this point in the history
* Support ignore_index for c_softmax_with_cross_entropy_op.

* Polish code. Remove useless comments and add Testcase.

* Polish code for TestCase.

* Polish code.

* Polish code style.

* Polish code.

* Change loss calculation formula and ignore_index dtype.

* Polish TestCase.

* Fix bug of c_softmax_with_cross_entropy_op_xpu_op. Attribute 'ignore_index'
dtype is int64_t.
  • Loading branch information
GhostScreaming authored Mar 30, 2023
1 parent dfa893f commit 8ef9708
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ template <typename DeviceContext, typename T>
class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int ignore_index = ctx.Attr<int>("ignore_index");
const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %d",
"ignore_index should be <=0, however it's %ld",
ignore_index));
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
Expand Down Expand Up @@ -460,12 +460,12 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
const phi::DenseTensor* softmax =
context.Input<phi::DenseTensor>("Softmax");
const int ignore_index = context.Attr<int>("ignore_index");
const int64_t ignore_index = context.Attr<int64_t>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %d",
"ignore_index should be <=0, however it's %ld",
ignore_index));
const int rank = context.Attr<int>("rank");
auto& dev_ctx = context.template device_context<DeviceContext>();
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distributed/fleet/layers/mpu/mp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ class ParallelCrossEntropy(paddle.nn.Layer):
mp_group(Group): The tensor parallel group.
name(str, optional): Normally there is no need for user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
ignore_index (int, optional): Specifies a target value that is ignored and
ignore_index (long int, optional): Specifies a target value that is ignored and
does not contribute to the loss. A negative value means that no label value
needs to be ignored. Default is -100 .
Expand Down

0 comments on commit 8ef9708

Please sign in to comment.