diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py
index 430d179fdab2..6c275575206b 100644
--- a/nemo/collections/asr/losses/rnnt.py
+++ b/nemo/collections/asr/losses/rnnt.py
@@ -34,6 +34,7 @@
 import torch
 from omegaconf import DictConfig, OmegaConf
 
+from nemo.collections.asr.losses.rnnt_pytorch import RNNTLossPytorch
 from nemo.core.classes import Loss, typecheck
 from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
 from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE
@@ -87,6 +88,13 @@ class RNNTLossConfig:
         is_available=NUMBA_RNNT_AVAILABLE,
         installation_msg=NUMBA_INSTALLATION_MESSAGE,
     ),
+    "pytorch": RNNTLossConfig(
+        loss_name="pytorch",
+        lib_name="torch",
+        min_version='0.0',
+        is_available=True,
+        installation_msg="Pure PyTorch implementation of RNN-T loss. Slow and intended for debugging purposes only.",
+    ),
 }
 
 RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt_numba']
@@ -165,6 +173,10 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
         loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp)
         _warn_unused_additional_kwargs(loss_name, loss_kwargs)
 
+    elif loss_name == 'pytorch':
+        loss_func = RNNTLossPytorch(blank=blank_idx, reduction='none')
+        _warn_unused_additional_kwargs(loss_name, loss_kwargs)
+
     else:
         raise ValueError(
             f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}"
diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py
new file mode 100644
index 000000000000..8400edac4e23
--- /dev/null
+++ b/nemo/collections/asr/losses/rnnt_pytorch.py
@@ -0,0 +1,112 @@
+# ! /usr/bin/python
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from nemo.core.classes import Loss
+from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
+
+
+class RNNTLossPytorch(Loss):
+    @property
+    def input_types(self):
+        """Input types definitions for RNNTLossPytorch.
+        """
+        return {
+            "acts": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
+            "labels": NeuralType(('B', 'T'), LabelsType()),
+            "act_lens": NeuralType(tuple('B'), LengthsType()),
+            "label_lens": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        """Output types definitions for RNNTLossPytorch.
+        loss:
+            NeuralType(None)
+        """
+        return {"loss": NeuralType(elements_type=LossType())}
+
+    def __init__(self, blank, reduction):
+        super().__init__()
+        self.blank = blank
+        self.reduction = reduction
+
+    def forward(self, acts, labels, act_lens, label_lens):
+        acts = torch.log_softmax(acts, -1)
+        forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens)
+        losses = -forward_logprob
+        if self.reduction == 'mean_batch':
+            losses = losses.mean()  # global batch size average
+        elif self.reduction == 'mean':
+            losses = torch.div(losses, label_lens).mean()
+        elif self.reduction == 'sum':
+            losses = losses.sum()
+        elif self.reduction == 'mean_volume':
+            losses = losses.sum() / label_lens.sum()  # same as above but longer samples weigh more
+
+        return losses
+
+    def compute_forward_prob(self, acts, labels, act_lens, label_lens):
+        B, T, U, _ = acts.shape
+
+        log_alpha = torch.zeros(B, T, U)
+        log_alpha = log_alpha.to(acts.device)
+
+        for t in range(T):
+            for u in range(U):
+                if u == 0:
+                    if t == 0:
+                        # this is the base case: (t=0, u=0) with log-alpha = 0.
+                        log_alpha[:, t, u] = 0.0
+                    else:
+                        # this is the case (u = 0, t > 0): reached only from (t - 1, u)
+                        # by emitting a blank symbol.
+                        log_alpha[:, t, u] = log_alpha[:, t - 1, u] + acts[:, t - 1, 0, self.blank]
+                else:
+                    if t == 0:
+                        # in case of (u > 0, t = 0), this is only reached from
+                        # (t, u - 1) with a label emission.
+                        gathered = torch.gather(
+                            acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
+                        ).reshape(-1)
+                        log_alpha[:, t, u] = log_alpha[:, t, u - 1] + gathered.to(log_alpha.device)
+                    else:
+                        # here both t and u are > 0, this state is reachable
+                        # with two possibilities: (t - 1, u) with a blank emission
+                        # or (t, u - 1) with a label emission.
+                        log_alpha[:, t, u] = torch.logsumexp(
+                            torch.stack(
+                                [
+                                    log_alpha[:, t - 1, u] + acts[:, t - 1, u, self.blank],
+                                    log_alpha[:, t, u - 1]
+                                    + torch.gather(
+                                        acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
+                                    ).reshape(-1),
+                                ]
+                            ),
+                            dim=0,
+                        )
+
+        log_probs = []
+        for b in range(B):
+            # here we need to add the final blank emission weights.
+            to_append = (
+                log_alpha[b, act_lens[b] - 1, label_lens[b]] + acts[b, act_lens[b] - 1, label_lens[b], self.blank]
+            )
+            log_probs.append(to_append)
+        log_prob = torch.stack(log_probs)
+
+        return log_prob
diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py
index d303e5355bf9..e1b0ce314b38 100644
--- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py
+++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py
@@ -16,6 +16,7 @@
 import pytest
 import torch
 
+from nemo.collections.asr.losses.rnnt import RNNTLossPytorch
 from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import RNNTLoss as RNNTLoss_Numpy
 from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba
 from nemo.core.utils import numba_utils
@@ -85,6 +86,9 @@ def test_case_small(self, device):
         fn_np = RNNTLoss_Numpy()
         np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device)
 
+        fn_ag = RNNTLossPytorch(blank=0, reduction='sum')  # ag for automatic gradient computation
+        ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device)
+
         expected_cost = 4.495666
         expected_grads = np.array(
             [
@@ -109,6 +113,9 @@ def test_case_small(self, device):
         assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch."
         assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch."
+ assert np.allclose(ag_cost, np_cost, rtol=1e-6), "small_test costs mismatch." + assert np.allclose(ag_grads, np_grads), "small_test gradient mismatch." + @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) def test_case_small_random(self, device): @@ -125,9 +132,15 @@ def test_case_small_random(self, device): fn_np = RNNTLoss_Numpy() np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device) + fn_ag = RNNTLossPytorch(blank=0, reduction='sum') # ag for automatic gradient computation + ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) + assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_random_test costs mismatch." assert np.allclose(pt_grads, np_grads), "small_random_test gradient mismatch." + assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "small_random_test costs mismatch." + assert np.allclose(pt_grads, ag_grads), "small_random_test gradient mismatch." + @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) @pytest.mark.parametrize('fastemit_lambda', [1.0, 0.01, 0.00001]) @@ -259,12 +272,18 @@ def test_case_big_tensor(self, device): fn_np = RNNTLoss_Numpy() np_costs, np_grads = wrap_and_call(fn_np, activations, labels, device) + fn_ag = RNNTLossPytorch(blank=0, reduction='sum') + ag_costs, ag_grads = wrap_and_call(fn_ag, activations, labels, device) + assert np.allclose(pt_costs, sum(expected_costs)), "big_test average costs mismatch." assert np.allclose(pt_grads, expected_grads, rtol=1e-3), "big_test grads for average cost mismatch." assert np.allclose(pt_costs, np_costs), "big_test average costs mismatch." assert np.allclose(pt_grads, np_grads, rtol=1e-3), "big_test grads for average cost mismatch." + assert np.allclose(pt_costs, ag_costs), "big_test average costs mismatch." + assert np.allclose(pt_grads, ag_grads, rtol=1e-3), "big_test grads for average cost mismatch." + @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) def test_case_large_random(self, device): @@ -286,8 +305,13 @@ def test_case_large_random(self, device): fn_np = RNNTLoss_Numpy() np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device) + fn_ag = RNNTLossPytorch(blank=0, reduction='sum') + ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) + assert np.allclose(pt_cost, np_cost, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch." + assert np.allclose(ag_cost, np_cost, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch." assert np.allclose(pt_grads, np_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch." + assert np.allclose(ag_grads, np_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch." @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES)
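Usage sketch: RNNTLossPytorch fills the standard RNN-T forward variable log_alpha[t, u] (blank transition from (t - 1, u), label transition from (t, u - 1)) and adds a final blank emission at (act_len - 1, label_len), so the gradients compared in the tests come from plain autograd (hence the `ag_` prefix). A minimal, hypothetical call looks like the snippet below; the tensor sizes and blank index are illustrative only, and the same implementation can also be selected through resolve_rnnt_loss by passing loss_name='pytorch'.

# Illustrative only: batch, acoustic frames, target length, vocab size (blank assumed at index 0).
import torch

from nemo.collections.asr.losses.rnnt_pytorch import RNNTLossPytorch

B, T, U, V = 2, 8, 4, 6
acts = torch.randn(B, T, U + 1, V, requires_grad=True)  # joint-network outputs; log_softmax is applied inside
labels = torch.randint(1, V, (B, U))                     # blank is never a target label
act_lens = torch.tensor([T, T - 2])
label_lens = torch.tensor([U, U - 1])

loss_fn = RNNTLossPytorch(blank=0, reduction='sum')
loss = loss_fn(acts=acts, labels=labels, act_lens=act_lens, label_lens=label_lens)
loss.backward()                                          # gradients via autograd, matching the ag_* test values
print(loss.item(), acts.grad.shape)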