【PIR API adaptor No.248、249】 Migrate CTCLoss/RNNTLoss into pir (Paddle…
DrRyanHuang authored and SecretXV committed Nov 28, 2023
1 parent 1280620 commit 7a6379c
Showing 4 changed files with 20 additions and 26 deletions.
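Both warpctc and warprnnt are changed the same way in this commit: the dygraph-only gate in_dynamic_mode() is widened to in_dynamic_or_pir_mode(), so the eager C++ ops path is also taken while a PIR program is being built. A minimal sketch of that pattern follows (illustrative only, not the commit's code; the import path is an assumption about where the helper lives):

# Simplified sketch of the dispatch pattern this commit applies; not the
# actual Paddle source. The import path is an assumption based on where
# in_dynamic_or_pir_mode is defined in recent Paddle releases.
from paddle.base.framework import in_dynamic_or_pir_mode


def some_loss_api(*args, **kwargs):
    if in_dynamic_or_pir_mode():
        # Dygraph and the new PIR static graph both call the C++ ops directly
        # (e.g. _C_ops.warpctc / _C_ops.warprnnt in this commit).
        pass
    else:
        # The legacy static graph still goes through the LayerHelper /
        # append_op machinery.
        pass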
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/loss.py
@@ -1893,7 +1893,7 @@ def warpctc(
input_length=None,
label_length=None,
):
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
if input_length is None or label_length is None:
raise ValueError(
"input_length and label_length must not be None in dygraph mode!"
@@ -2017,7 +2017,7 @@ def rnnt_loss(
def warprnnt(
input, label, input_length, label_length, blank=0, fastemit_lambda=0.001
):
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
loss_out = _C_ops.warprnnt(
input,
label,
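For context, a hedged usage sketch of the public paddle.nn.functional.ctc_loss API that sits on top of the migrated warpctc branch; the shapes, dtypes, and reduction mode below are illustrative assumptions, not taken from the commit:

# Illustrative only: values are made up for demonstration.
import paddle
import paddle.nn.functional as F

T, N, C = 5, 2, 3  # time steps, batch size, number of non-blank classes
logits = paddle.randn([T, N, C + 1], dtype="float32")
log_probs = F.log_softmax(logits, axis=-1)
labels = paddle.randint(1, C + 1, [N, 3], dtype="int32")  # blank index is 0
input_lengths = paddle.full([N], T, dtype="int64")
label_lengths = paddle.full([N], 3, dtype="int64")

# After this commit the same call is expected to reach the _C_ops.warpctc
# path under both dygraph and the PIR static graph.
loss = F.ctc_loss(log_probs, labels, input_lengths, label_lengths,
                  blank=0, reduction="mean")
print(loss)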
8 changes: 1 addition & 7 deletions python/paddle/tensor/linalg.py
@@ -3500,13 +3500,7 @@ def lstsq(x, y, rcond=None, driver=None, name=None):
x, y, rcond, driver
)
if driver == "gels":
if in_dynamic_mode():
rank = paddle.empty(shape=[0], dtype=paddle.int32)

else:
rank = paddle.empty(
shape=[0], dtype=paddle.base.core.DataType.INT32
)
rank = paddle.empty(shape=[0], dtype="int32")
singular_values = paddle.empty(shape=[0], dtype=x.dtype)
elif driver == "gelsy":
singular_values = paddle.empty(shape=[0], dtype=x.dtype)
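A short, hedged usage sketch of paddle.linalg.lstsq with the "gels" driver, the branch simplified above; the inputs are made-up examples:

# Illustrative only; not from the commit. With the change above, the empty
# `rank` tensor for the "gels" driver is created with the plain "int32"
# dtype string in every execution mode.
import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = paddle.to_tensor([[1.0], [2.0], [3.0]])

# "gels" assumes x has full rank, so lstsq returns an empty rank tensor here.
solution, residuals, rank, singular_values = paddle.linalg.lstsq(
    x, y, driver="gels"
)
print(solution.shape, rank.shape)  # rank is an empty int32 tensor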
20 changes: 13 additions & 7 deletions test/legacy_test/test_warpctc_op.py
@@ -21,7 +21,7 @@

import paddle
import paddle.nn.functional as F
from paddle.base import Program, core, program_guard
from paddle.base import core

CUDA_BLOCK_SIZE = 32

@@ -394,7 +394,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
@@ -404,13 +404,15 @@ def test_check_grad(self):
"Loss",
max_relative_error=0.009,
check_dygraph=False,
check_pir=True,
)
else:
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.007,
check_dygraph=False,
check_pir=True,
)


@@ -516,17 +518,21 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(["Logits"], "Loss")
self.check_grad(["Logits"], "Loss", check_pir=True)


class TestWarpCTCOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(
main_program=main_program, startup_program=startup_program
):
logits = paddle.static.data(
name='logits', shape=[5, 16, 6], dtype='float32'
)
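As a standalone illustration of the updated setup style (mirroring the hunk above, not adding new behavior), the error test now builds its programs explicitly through paddle.static instead of importing Program and program_guard from paddle.base:

# Sketch of the static-graph setup used by the updated error test.
import paddle

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    logits = paddle.static.data(
        name="logits", shape=[5, 16, 6], dtype="float32"
    )
paddle.disable_static()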
@@ -660,7 +666,7 @@ def test_class_api(self):
np.testing.assert_allclose(loss_pd, loss_np, rtol=1e-05, atol=1)

def test_eager_ctcloss(self):
def test_functinal_api():
def test_functional_api():
self.batch_size = 4
self.num_classes = CUDA_BLOCK_SIZE + 2
self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64)
@@ -730,7 +736,7 @@ def test_functinal_api():
loss_pd_sum, loss_np_sum, rtol=1e-05, atol=1
)

test_functinal_api()
test_functional_api()


if __name__ == "__main__":
14 changes: 4 additions & 10 deletions test/legacy_test/test_warprnnt_op.py
@@ -233,9 +233,7 @@ def test_check_grad(self):
self.outputs["warprnntgrad"] = self.gradient
if core.is_compiled_with_rocm():
self.check_grad(
["input"],
"loss",
numeric_grad_delta=0.009,
["input"], "loss", numeric_grad_delta=0.009, check_pir=True
)
else:
self.check_grad(
@@ -246,22 +244,18 @@
class TestWarpRNNTFP64Op(TestWarpRNNTOp):
def test_check_output(self):
self.acts.astype(np.float64)
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.acts.astype(np.float64)
self.outputs["warprnntgrad"] = self.gradient
if core.is_compiled_with_rocm():
self.check_grad(
["input"],
"loss",
numeric_grad_delta=0.009,
["input"], "loss", numeric_grad_delta=0.009, check_pir=True
)
else:
self.check_grad(
["input"],
"loss",
numeric_grad_delta=0.009,
["input"], "loss", numeric_grad_delta=0.009, check_pir=True
)


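To round out the RNN-T side, a hedged usage sketch of paddle.nn.functional.rnnt_loss, which dispatches through the migrated warprnnt branch; every shape, dtype, and the log-softmax step are assumptions made for illustration:

# Illustrative only; not part of the commit.
import paddle
import paddle.nn.functional as F

B, Tmax, Umax, D = 2, 4, 3, 5  # batch, acoustic steps, target steps + 1, vocab incl. blank
acts = paddle.randn([B, Tmax, Umax, D], dtype="float32")
log_probs = F.log_softmax(acts, axis=-1)  # assumed input convention
labels = paddle.randint(1, D, [B, Umax - 1], dtype="int32")
input_lengths = paddle.full([B], Tmax, dtype="int32")
label_lengths = paddle.full([B], Umax - 1, dtype="int32")

loss = F.rnnt_loss(log_probs, labels, input_lengths, label_lengths,
                   blank=0, fastemit_lambda=0.001)
print(loss)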
