Skip to content

Commit

Permalink
[PIR]Migrate CrossEntropyLoss into pir (#58519)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Nov 2, 2023
1 parent 5cce0a5 commit f663ef1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 25 deletions.
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,7 +2050,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
if epsilon > 1.0 or epsilon < 0.0:
raise ValueError("The value of epsilon must be between 0 and 1.")

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.label_smooth(label, prior_dist, float(epsilon))

check_variable_and_dtype(
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...base.data_feeder import check_variable_and_dtype
from ...base.layer_helper import LayerHelper
from ...common_ops_import import Variable
from ...framework import in_dynamic_mode, in_dynamic_or_pir_mode
from ...framework import in_dynamic_or_pir_mode

__all__ = []

Expand Down Expand Up @@ -89,7 +89,7 @@ def one_hot(x, num_classes, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.one_hot(x, num_classes)
else:
check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2')
Expand Down
51 changes: 30 additions & 21 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from paddle.utils import deprecated

from ...base.data_feeder import check_variable_and_dtype
from ...base.framework import _current_expected_place
from ...base.framework import _current_expected_place, in_pir_mode
from ...base.layer_helper import LayerHelper
from ...common_ops_import import Variable
from ...tensor.manipulation import reshape
Expand Down Expand Up @@ -2935,24 +2935,31 @@ def cross_entropy(
['uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
'softmax_cross_entropy',
)
attrs = {
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': True,
'axis': axis,
'use_softmax': use_softmax,
}
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
out = helper.create_variable_for_type_inference(dtype=input.dtype)
if in_pir_mode():
softmax, out = _C_ops.cross_entropy_with_softmax(
input, label, soft_label, use_softmax, True, ignore_index, axis
)
else:
attrs = {
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': True,
'axis': axis,
'use_softmax': use_softmax,
}
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(
dtype=input.dtype
)
out = helper.create_variable_for_type_inference(dtype=input.dtype)

outputs = {'Softmax': softmax, 'Loss': out}
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': input, 'Label': label},
outputs=outputs,
attrs=attrs,
)
outputs = {'Softmax': softmax, 'Loss': out}
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': input, 'Label': label},
outputs=outputs,
attrs=attrs,
)

if weight is not None:
check_variable_and_dtype(
Expand Down Expand Up @@ -3036,19 +3043,21 @@ def cross_entropy(
if weight is None:
mask = paddle.cast(mask, dtype=out_sum.dtype)
count = paddle.sum(mask, name=name)
ret = out_sum / (count + (count == 0.0))
ret = out_sum / (count + paddle.equal(count, 0.0))
else:
mask = paddle.cast(mask, weight_gather_reshape.dtype)
weight_ignored = paddle.multiply(
mask, weight_gather_reshape
)
weight_sum = paddle.sum(weight_ignored, name=name)
ret = out_sum / (weight_sum + (weight_sum == 0.0))
ret = out_sum / (weight_sum + paddle.equal(weight_sum, 0.0))
return ret
elif weight is not None:
out_sum = paddle.sum(out, name=name)
total_weight = paddle.sum(weight_gather_reshape)
return out_sum / (total_weight + (total_weight == 0.0))
return out_sum / (
total_weight + paddle.equal(total_weight, 0.0)
)
else:
return paddle.mean(out, name=name)

Expand Down
33 changes: 32 additions & 1 deletion test/legacy_test/test_cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import paddle
from paddle import base
from paddle.base import Program, program_guard
from paddle.pir_utils import test_with_pir_api


def label_smooth(label, C, epsilon, is_onehot=True):
Expand Down Expand Up @@ -272,6 +273,7 @@ def test_softmax_with_cross_entropy(self):

# soft_label test start
# soft_label test 1
@test_with_pir_api
def test_cross_entropy_loss_soft_1d(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -360,6 +362,7 @@ def test_cross_entropy_loss_soft_1d(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# soft_label test 2
@test_with_pir_api
def test_cross_entropy_loss_soft_1d_weight(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -460,6 +463,7 @@ def test_cross_entropy_loss_soft_1d_weight(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# soft_label test 3
@test_with_pir_api
def test_cross_entropy_loss_soft_1d_mean(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -544,6 +548,7 @@ def test_cross_entropy_loss_soft_1d_mean(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# soft_label test 4
@test_with_pir_api
def test_cross_entropy_loss_soft_1d_weight_mean(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -634,6 +639,7 @@ def test_cross_entropy_loss_soft_1d_weight_mean(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# soft_label test 5
@test_with_pir_api
def test_cross_entropy_loss_soft_2d(self):
def inner_cross_entropy_loss_soft_2d(soft_label):
self.numeric_stable_mode = False
Expand Down Expand Up @@ -739,6 +745,7 @@ def inner_cross_entropy_loss_soft_2d(soft_label):
inner_cross_entropy_loss_soft_2d(False)

# soft_label test 6
@test_with_pir_api
def test_cross_entropy_loss_soft_2d_weight_mean(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -840,6 +847,7 @@ def test_cross_entropy_loss_soft_2d_weight_mean(self):
# soft_label test end

# label_smoothing test 1
@test_with_pir_api
def test_cross_entropy_loss_onehot_label_smoothing_1d(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -937,6 +945,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_1d(self):
paddle.enable_static()

# label_smoothing test 2
@test_with_pir_api
def test_cross_entropy_loss_onehot_label_smoothing_1d_weight_mean(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -1036,6 +1045,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_1d_weight_mean(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# label_smoothing test 3
@test_with_pir_api
def test_cross_entropy_loss_onehot_label_smoothing_2d(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -1143,6 +1153,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_2d(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# label_smoothing test 4
@test_with_pir_api
def test_cross_entropy_loss_onehot_label_smoothing_2d_weight_mean(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -1253,6 +1264,7 @@ def test_cross_entropy_loss_onehot_label_smoothing_2d_weight_mean(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# label_smoothing test 5
@test_with_pir_api
def test_cross_entropy_loss_integer_label_smoothing_1d(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -1350,6 +1362,7 @@ def test_cross_entropy_loss_integer_label_smoothing_1d(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# label_smoothing test 6
@test_with_pir_api
def test_cross_entropy_loss_integer_label_smoothing_1d_weight_mean(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -1452,6 +1465,7 @@ def test_cross_entropy_loss_integer_label_smoothing_1d_weight_mean(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# label_smoothing test 7
@test_with_pir_api
def test_cross_entropy_loss_integer_label_smoothing_2d(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -1557,6 +1571,7 @@ def test_cross_entropy_loss_integer_label_smoothing_2d(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# label_smoothing test 8
@test_with_pir_api
def test_cross_entropy_loss_integer_label_smoothing_2d_weight_mean(self):
self.numeric_stable_mode = False
self.soft_label = True
Expand Down Expand Up @@ -1667,7 +1682,7 @@ def test_cross_entropy_loss_integer_label_smoothing_2d_weight_mean(self):
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

# label_smoothing test end

@test_with_pir_api
def test_cross_entropy_loss_1d_with_mean_ignore(self):
input_np = np.random.random([2, 4]).astype(self.dtype)
label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
Expand Down Expand Up @@ -1714,6 +1729,7 @@ def test_cross_entropy_loss_1d_with_mean_ignore(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_with_mean_ignore_negative(self):
N = 100
C = 200
Expand Down Expand Up @@ -1763,6 +1779,7 @@ def test_cross_entropy_loss_1d_with_mean_ignore_negative(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_with_weight_mean_ignore(self):
N = 100
C = 200
Expand Down Expand Up @@ -1846,6 +1863,7 @@ def test_cross_entropy_loss_1d_with_weight_mean_ignore_exceedlabel(self):

np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_with_weight_mean(self):
input_np = np.random.random([2, 4]).astype(self.dtype)
label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
Expand Down Expand Up @@ -1901,6 +1919,7 @@ def test_cross_entropy_loss_1d_with_weight_mean(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_with_weight_sum(self):
input_np = np.random.random([100, 200]).astype(self.dtype) # N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1
Expand Down Expand Up @@ -1954,6 +1973,7 @@ def test_cross_entropy_loss_1d_with_weight_sum(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_with_weight_none(self):
input_np = np.random.random([100, 200]).astype(self.dtype) # N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1
Expand Down Expand Up @@ -2011,6 +2031,7 @@ def test_cross_entropy_loss_1d_with_weight_none(self):
np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_with_weight_none_func(self):
input_np = np.random.random([100, 200]).astype(self.dtype) # N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N
Expand Down Expand Up @@ -2064,6 +2085,7 @@ def test_cross_entropy_loss_1d_with_weight_none_func(self):
np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_mean(self):
input_np = np.random.random([100, 200]).astype(self.dtype) # N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1
Expand Down Expand Up @@ -2102,6 +2124,7 @@ def test_cross_entropy_loss_1d_mean(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_sum(self):
input_np = np.random.random([100, 200]).astype(self.dtype) # N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1
Expand Down Expand Up @@ -2144,6 +2167,7 @@ def test_cross_entropy_loss_1d_sum(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_1d_none(self):
input_np = np.random.random([100, 200]).astype(self.dtype) # N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) # N,1
Expand Down Expand Up @@ -2188,6 +2212,7 @@ def test_cross_entropy_loss_1d_none(self):
np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_2d_with_weight_none(self):
input_np = np.random.random(size=(2, 2, 2, 3)).astype(
self.dtype
Expand Down Expand Up @@ -2250,6 +2275,7 @@ def test_cross_entropy_loss_2d_with_weight_none(self):
np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_2d_with_weight_axis_change_mean(self):
input_np = np.random.random(size=(2, 3, 2, 2)).astype(
self.dtype
Expand Down Expand Up @@ -2341,6 +2367,7 @@ def test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(self):
)[0]
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_2d_with_weight_mean(self):
input_np = np.random.random(size=(2, 2, 2, 3)).astype(
self.dtype
Expand Down Expand Up @@ -2400,6 +2427,7 @@ def test_cross_entropy_loss_2d_with_weight_mean(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_2d_with_weight_sum(self):
input_np = np.random.random(size=(2, 2, 2, 3)).astype(
self.dtype
Expand Down Expand Up @@ -2460,6 +2488,7 @@ def test_cross_entropy_loss_2d_with_weight_sum(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_2d_none(self):
input_np = np.random.random(size=(2, 2, 2, 3)).astype(
self.dtype
Expand Down Expand Up @@ -2513,6 +2542,7 @@ def test_cross_entropy_loss_2d_none(self):
np.testing.assert_allclose(static_ret, expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_2d_mean(self):
input_np = np.random.random(size=(2, 2, 2, 3)).astype(
self.dtype
Expand Down Expand Up @@ -2567,6 +2597,7 @@ def test_cross_entropy_loss_2d_mean(self):
np.testing.assert_allclose(static_ret[0], expected, rtol=1e-05)
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

@test_with_pir_api
def test_cross_entropy_loss_2d_sum(self):
input_np = np.random.random(size=(2, 2, 2, 3)).astype(
self.dtype
Expand Down

0 comments on commit f663ef1

Please sign in to comment.