Skip to content

Commit

Permalink
【PIR API adaptor No.247】python/paddle/text/viterbi_decode.py (#58785)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liyulingyue authored Nov 10, 2023
1 parent ec729e2 commit 7007d8f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/paddle/text/viterbi_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from paddle import _C_ops

from ..base.data_feeder import check_type, check_variable_and_dtype
from ..base.framework import in_dygraph_mode
from ..base.framework import in_dynamic_or_pir_mode
from ..base.layer_helper import LayerHelper
from ..nn import Layer

Expand Down Expand Up @@ -64,7 +64,7 @@ def viterbi_decode(
[[0, 0],
[1, 1]])
"""
if in_dygraph_mode():
if in_dynamic_or_pir_mode():
return _C_ops.viterbi_decode(
potentials, transition_params, lengths, include_bos_eos_tag
)
Expand Down
4 changes: 3 additions & 1 deletion test/legacy_test/test_viterbi_decode_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand Down Expand Up @@ -99,7 +100,7 @@ def setUp(self):
self.outputs = {'Scores': scores, 'Path': path}

def test_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestViterbiAPI(unittest.TestCase):
Expand All @@ -121,6 +122,7 @@ def setUp(self):
decoder = Decoder(self.transitions, self.use_tag)
self.scores, self.path = decoder(self.input, self.length)

@test_with_pir_api
def check_static_result(self, place):
bz, length, ntags = self.bz, self.len, self.ntags
with base.program_guard(base.Program(), base.Program()):
Expand Down

0 comments on commit 7007d8f

Please sign in to comment.