add functional ctc_loss and CTCLoss class. #26384

Merged 5 commits on Aug 21, 2020
148 changes: 125 additions & 23 deletions python/paddle/fluid/tests/unittests/test_warpctc_op.py
@@ -21,25 +21,25 @@
from test_softmax_op import stable_softmax
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle
import paddle.nn.functional as F

CUDA_BLOCK_SIZE = 512


class CTCForward(object):
def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
norm_by_times):
def __init__(self, softmax, softmax_lod, labels, labels_lod, num_classes,
batch_size, blank, norm_by_times):
self.softmax = softmax
self.softmax_lod = softmax_lod
assert labels.shape[1] == 1
self.labels = labels
self.labels_lod = labels_lod
self.blank = blank
self.norm_by_times = norm_by_times

self.level = 0
self.num_classes = softmax.shape[1]
self.batch_size = len(softmax_lod[self.level])
assert self.batch_size == len(labels_lod[self.level])
self.num_classes = num_classes
self.batch_size = batch_size

self.loss = np.zeros([self.batch_size, 1], dtype="float32")
self.gradient = np.zeros(self.softmax.shape, dtype="float32")
@@ -163,17 +163,25 @@ def forward(self):
softmax_offset = 0
labels_offset = 0
for i in range(self.batch_size):
softmax_start_i = softmax_offset
softmax_end_i = softmax_offset + self.softmax_lod[self.level][i]
labels_start_i = labels_offset
labels_end_i = labels_offset + self.labels_lod[self.level][i]

softmax_a_sequence = self.softmax[softmax_start_i:softmax_end_i, :]
labels_a_sequence = self.labels[labels_start_i:labels_end_i, :]
self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
labels_a_sequence)
softmax_offset += self.softmax_lod[self.level][i]
labels_offset += self.labels_lod[self.level][i]
if self.labels.shape[1] == 1:
softmax_start_i = softmax_offset
softmax_end_i = softmax_offset + self.softmax_lod[self.level][i]
labels_start_i = labels_offset
labels_end_i = labels_offset + self.labels_lod[self.level][i]

softmax_a_sequence = self.softmax[softmax_start_i:
softmax_end_i, :]
labels_a_sequence = self.labels[labels_start_i:labels_end_i, :]
self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
labels_a_sequence)
softmax_offset += self.softmax_lod[self.level][i]
labels_offset += self.labels_lod[self.level][i]
else:
softmax_a_sequence = self.softmax[:self.softmax_lod[i], i, :]
labels_a_sequence = self.labels[:self.labels_lod[i], :]
self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
labels_a_sequence)

return self.loss
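
The updated reference above now handles two input layouts: the original LoD layout, where all sequences are concatenated along the first axis and per-sequence lengths live in softmax_lod / labels_lod, and the padded layout used by the new API, where the probabilities form a dense [max_logit_length, batch_size, num_classes] tensor and the labels form a [batch_size, max_label_length] matrix. A small NumPy sketch of the two layouts, with shapes taken from the test cases below (illustration only, not part of the diff):

import numpy as np

batch_size = 4
num_classes = 512 + 2                 # CUDA_BLOCK_SIZE + 2, as in the tests
seq_lengths = [4, 1, 3, 3]            # logit sequence lengths
lab_lengths = [3, 1, 4, 4]            # label sequence lengths

# LoD layout (original path): sequences concatenated along axis 0,
# with per-sequence lengths kept in a one-level LoD.
probs_lod = np.random.rand(sum(seq_lengths), num_classes).astype("float32")
probs_lod /= probs_lod.sum(axis=-1, keepdims=True)   # normalize rows to probabilities
labels_lod = np.random.randint(
    0, num_classes - 1, [sum(lab_lengths), 1], dtype="int32")

# Padded layout (new path, used by F.ctc_loss): dense tensors plus 1-D length arrays.
logits_length = np.array(seq_lengths, dtype=np.int64)
labels_length = np.array(lab_lengths, dtype=np.int64)
probs_padded = np.random.rand(
    max(seq_lengths), batch_size, num_classes).astype("float32")
probs_padded /= probs_padded.sum(axis=-1, keepdims=True)
labels_padded = np.random.randint(
    0, num_classes - 1, [batch_size, max(lab_lengths)], dtype="int32")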


@@ -201,7 +209,8 @@ def setUp(self):
dtype="int32")

ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
self.blank, self.norm_by_times)
self.num_classes, self.batch_size, self.blank,
self.norm_by_times)
loss = ctc.forward()

max_sequence_length = 0
@@ -223,7 +232,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
@@ -237,7 +246,7 @@ def config(self):
self.num_classes = CUDA_BLOCK_SIZE + 2
self.logits_lod = [[4, 1, 3, 3]]
self.labels_lod = [[3, 1, 4, 4]]
self.blank = 0
self.blank = self.num_classes - 1
self.norm_by_times = False


@@ -267,7 +276,8 @@ def setUp(self):
dtype="int32")

ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
self.blank, self.norm_by_times)
self.num_classes, self.batch_size, self.blank,
self.norm_by_times)
loss = ctc.forward()

max_sequence_length = 0
@@ -317,7 +327,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
@@ -333,7 +343,7 @@ def config(self):
self.labels_lod = [[3, 1, 4, 4]]
self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64)
self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64)
self.blank = 0
self.blank = self.num_classes - 1
self.norm_by_times = False


@@ -389,5 +399,97 @@ def test_label_len_Variable():
self.assertRaises(TypeError, test_label_len_Variable)


class TestCTCLossAPICase(unittest.TestCase):
def test_functinal_api(self):
self.batch_size = 4
self.num_classes = CUDA_BLOCK_SIZE + 2
self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64)
self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64)
self.blank = self.num_classes - 1
self.norm_by_times = False

logits = np.random.uniform(0.1, 1.0, [
max(self.logits_length), self.batch_size, self.num_classes
]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, -1, logits)
# labels should not be blank
labels = np.random.randint(
0,
self.num_classes - 1, [self.batch_size, max(self.labels_length)],
dtype="int32")

ctc = CTCForward(softmax, self.logits_length, labels,
self.labels_length, self.num_classes, self.batch_size,
self.blank, self.norm_by_times)
loss_np = ctc.forward()

paddle.disable_static()
softmax = paddle.to_variable(logits)
labels = paddle.to_variable(labels)
logits_length = paddle.to_variable(self.logits_length)
labels_length = paddle.to_variable(self.labels_length)
loss_pd_mean = F.ctc_loss(
softmax,
labels,
logits_length,
labels_length,
blank=self.blank,
reduction='mean')
loss_pd_mean = loss_pd_mean.numpy()

loss_pd_sum = F.ctc_loss(
softmax,
labels,
logits_length,
labels_length,
blank=self.blank,
reduction='sum')
loss_pd_sum = loss_pd_sum.numpy()
paddle.enable_static()
loss_np = np.squeeze(loss_np, axis=-1)
loss_np_mean = (loss_np / labels_length.numpy()).mean()
loss_np_sum = loss_np.sum()

self.assertTrue(np.allclose(loss_pd_mean, loss_np_mean, atol=1))
self.assertTrue(np.allclose(loss_pd_sum, loss_np_sum, atol=1))

def test_class_api(self):
self.batch_size = 3
self.num_classes = 15
self.logits_length = np.array([3, 3, 3], dtype=np.int64)
self.labels_length = np.array([0, 1, 2], dtype=np.int64)
self.blank = 0
self.norm_by_times = False

logits = np.random.uniform(0.1, 1.0, [
max(self.logits_length), self.batch_size, self.num_classes
]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, -1, logits)
# labels should not be blank
labels = np.random.randint(
1,
self.num_classes, [self.batch_size, max(self.labels_length)],
dtype="int32")

ctc = CTCForward(softmax, self.logits_length, labels,
self.labels_length, self.num_classes, self.batch_size,
self.blank, self.norm_by_times)
loss_np = ctc.forward()

paddle.disable_static()
softmax = paddle.to_variable(logits)
labels = paddle.to_variable(labels)
logits_length = paddle.to_variable(self.logits_length)
labels_length = paddle.to_variable(self.labels_length)

loss_pd = paddle.nn.CTCLoss(self.blank, 'none')(
softmax, labels, logits_length, labels_length)
loss_pd = loss_pd.numpy()
paddle.enable_static()
loss_np = np.squeeze(loss_np, axis=-1)

self.assertTrue(np.allclose(loss_pd, loss_np, atol=1))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions python/paddle/nn/__init__.py
@@ -111,6 +111,7 @@
from .layer.loss import BCELoss #DEFINE_ALIAS
from .layer.loss import KLDivLoss #DEFINE_ALIAS
from .layer.loss import MarginRankingLoss #DEFINE_ALIAS
from .layer.loss import CTCLoss #DEFINE_ALIAS
from .layer.loss import SmoothL1Loss #DEFINE_ALIAS
from .layer.norm import BatchNorm #DEFINE_ALIAS
from .layer.norm import SyncBatchNorm #DEFINE_ALIAS
3 changes: 3 additions & 0 deletions python/paddle/nn/functional/__init__.py
@@ -25,6 +25,8 @@
__all__ += extension.__all__
from . import common
__all__ += common.__all__
from . import loss
__all__ += loss.__all__
from .activation import brelu #DEFINE_ALIAS
from .activation import elu #DEFINE_ALIAS
from .activation import erf #DEFINE_ALIAS
@@ -147,6 +149,7 @@
from .loss import square_error_cost #DEFINE_ALIAS
from .loss import ssd_loss #DEFINE_ALIAS
from .loss import teacher_student_sigmoid_loss #DEFINE_ALIAS
from .loss import ctc_loss #DEFINE_ALIAS
# from .norm import batch_norm #DEFINE_ALIAS
# from .norm import data_norm #DEFINE_ALIAS
# from .norm import group_norm #DEFINE_ALIAS
102 changes: 101 additions & 1 deletion python/paddle/nn/functional/loss.py
@@ -13,6 +13,9 @@
# limitations under the License.

import paddle
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype
import paddle.fluid as fluid

# TODO: define loss functions of neural network
import numpy as np
@@ -70,7 +73,8 @@
'softmax_with_cross_entropy',
'square_error_cost',
'ssd_loss',
'teacher_student_sigmoid_loss'
'teacher_student_sigmoid_loss',
'ctc_loss',
]


@@ -787,6 +791,102 @@ def mse_loss(input, label, reduction='mean', name=None):
name=name)


def ctc_loss(log_probs,
labels,
input_lengths,
label_lengths,
blank=0,
reduction='mean'):
"""

An operator that integrates the open-source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
to compute the Connectionist Temporal Classification (CTC) loss.
It can be viewed as softmax with CTC, since a native softmax activation
is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

Parameters:
log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32.
labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of the label sequences. The data type must be int32.
input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-open interval [0, num_classes + 1). The data type must be int32. Default is 0.
reduction (string, optional): Indicates how to average the loss; the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss is divided by label_lengths and the mean of the quotient is returned; if :attr:`reduction` is ``'sum'``, the sum of the loss is returned; if :attr:`reduction` is ``'none'``, no reduction is applied. Default is ``'mean'``.

Returns:
Tensor, the Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of the loss is [batch_size]; otherwise the shape is [1]. The data type is the same as ``log_probs``.

Examples:

.. code-block:: python

# declarative mode
import paddle.nn.functional as F
import numpy as np
import paddle

# length of the longest logit sequence
max_seq_length = 4
# length of the longest label sequence
max_label_length = 3
# number of logit sequences
batch_size = 2
# class num
class_num = 3

np.random.seed(1)
log_probs = np.array([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
[3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

[[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
[5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

[[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
[6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

[[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
[9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

[[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
[3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]]).astype("float32")
labels = np.array([[1, 2, 2],
[1, 2, 2]]).astype("int32")
input_lengths = np.array([5, 5]).astype("int64")
label_lengths = np.array([3, 3]).astype("int64")

paddle.disable_static()
log_probs = paddle.to_variable(log_probs)
labels = paddle.to_variable(labels)
input_lengths = paddle.to_variable(input_lengths)
label_lengths = paddle.to_variable(label_lengths)

loss = F.ctc_loss(log_probs, labels,
input_lengths,
label_lengths,
blank=0,
reduction='none')
print(loss.numpy()) #[3.9179852 2.9076521]

loss = F.ctc_loss(log_probs, labels,
input_lengths,
label_lengths,
blank=0,
reduction='mean')
print(loss.numpy()) #[1.1376063]

"""

loss_out = fluid.layers.warpctc(log_probs, labels, blank, False,
input_lengths, label_lengths)

loss_out = fluid.layers.squeeze(loss_out, [-1])
assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.mean(loss_out / label_lengths)
elif reduction == 'sum':
loss_out = paddle.sum(loss_out)
return loss_out


def cross_entropy(input,
label,
weight=None,
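The diff to python/paddle/nn/layer/loss.py, where the CTCLoss class named in the PR title is defined, is not included in this excerpt. As a rough sketch, a layer like it can simply wrap the functional API above; the constructor signature (blank, reduction) and the forward arguments are inferred from the test code, while the base class and defaults are assumptions:

import paddle.fluid as fluid
import paddle.nn.functional as F


class CTCLoss(fluid.dygraph.Layer):
    """Sketch of a loss layer that defers to the functional ctc_loss."""

    def __init__(self, blank=0, reduction='mean'):
        super(CTCLoss, self).__init__()
        self.blank = blank
        self.reduction = reduction

    def forward(self, log_probs, labels, input_lengths, label_lengths):
        # Delegate to F.ctc_loss, added by this PR.
        return F.ctc_loss(
            log_probs,
            labels,
            input_lengths,
            label_lengths,
            blank=self.blank,
            reduction=self.reduction)

Used as in the test above, paddle.nn.CTCLoss(blank, 'none')(softmax, labels, logits_length, labels_length); with reduction='mean' the per-sample loss is divided by its label length before averaging, matching the 'mean' branch in ctc_loss.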
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
@@ -76,6 +76,7 @@
from .loss import BCELoss #DEFINE_ALIAS
from .loss import KLDivLoss #DEFINE_ALIAS
from .loss import MarginRankingLoss #DEFINE_ALIAS
from .loss import CTCLoss #DEFINE_ALIAS
from .loss import SmoothL1Loss #DEFINE_ALIAS
from .norm import BatchNorm #DEFINE_ALIAS
from .norm import SyncBatchNorm #DEFINE_ALIAS