Significantly speed up RNNT loss on CUDA (#3653)
* Significantly speed up RNNT loss on CUDA

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add tests

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add tests

Signed-off-by: smajumdar <titu1994@gmail.com>
titu1994 authored Feb 15, 2022
1 parent de66b87 commit 2e16de3
Showing 4 changed files with 96 additions and 32 deletions.
nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py (7 changes: 2 additions & 5 deletions)
@@ -191,9 +191,6 @@ def rnnt_loss_gpu(
### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ###
acts, acts_shape = rnnt_helper.flatten_tensor(acts)

### REPRESENT THE CUDA ARRAY INTERFACE OF COSTS VECTOR ###
costs_repr = cuda.as_cuda_array(costs, sync=False) # NO COPY OF DATA, JUST CHANGE REPRESENTATION

wrapper = gpu_rnnt.GPURNNT(
minibatch=minibatch_size,
maxT=maxT,
@@ -210,7 +207,7 @@
if grads is None:
status = wrapper.score_forward(
acts=acts.data,
costs=costs_repr,
costs=costs.data,
pad_labels=labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
@@ -226,7 +223,7 @@
status = wrapper.cost_and_grad(
acts=acts.data,
grads=grads.data,
costs=costs_repr,
costs=costs.data,
pad_labels=labels.data,
label_lengths=label_lengths.data,
input_lengths=input_lengths.data,
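A side note on why the cuda.as_cuda_array(costs, sync=False) wrapper can be dropped: Numba CUDA kernels accept any object that exposes __cuda_array_interface__, and PyTorch CUDA tensors do, so costs can be handed to the Numba side directly and without a copy. Below is a minimal interop sketch, independent of the NeMo code; the kernel and tensor names are made up for illustration.

import torch
from numba import cuda

@cuda.jit
def negate_inplace(x):
    # hypothetical pointwise kernel: one thread per element of a 1D array
    idx = cuda.grid(1)
    if idx < x.shape[0]:
        x[idx] = -x[idx]

costs = torch.zeros(8, device='cuda')
threads = min(costs.shape[0], 32)
blocks = (costs.shape[0] + threads - 1) // threads
# the torch tensor is consumed directly via __cuda_array_interface__, no explicit wrapping needed
negate_inplace[blocks, threads](costs)
torch.cuda.synchronize()
print(costs)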
nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
@@ -27,13 +27,13 @@
# limitations under the License.

import multiprocessing
from typing import Optional
from typing import Optional, Tuple

import numba
import torch
from numba import cuda

from nemo.collections.asr.parts.numba.rnnt_loss.utils import global_constants
from nemo.collections.asr.parts.numba.rnnt_loss.utils import global_constants, rnnt_helper
from nemo.collections.asr.parts.numba.rnnt_loss.utils.cuda_utils import gpu_rnnt_kernel, reduce


@@ -83,6 +83,7 @@ def __init__(

if num_threads > 0:
numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
self.num_threads_ = numba.get_num_threads()
else:
self.num_threads_ = numba.get_num_threads()

@@ -147,27 +148,12 @@ def compute_cost_and_score(
An enum that either represents a successful RNNT operation or failure.
"""
training = grads is not None
used_offset = 0

# // denom
denom = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
used_offset += self.maxT_ * self.maxU_ * self.minibatch_

# // alphas & betas
alphas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
used_offset += self.maxT_ * self.maxU_ * self.minibatch_
betas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
used_offset += self.maxT_ * self.maxU_ * self.minibatch_

# // logllh
llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
used_offset += self.minibatch_
llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
used_offset += self.minibatch_

if training:
grads *= 0.0 # zero grads

used_offset, (denom, alphas, betas, llForward, llBackward) = self._prepare_workspace()

######## START EXECUTION ########
self.log_softmax(acts, denom)

@@ -226,16 +212,19 @@ def compute_cost_and_score(
self.clamp_,
)

# // cost
costs.copy_to_device(llForward, stream=self.stream_)
# // cost: copy, negate (to obtain the negative log likelihood) and apply the additional regularizers
# This has to happen on the GPU, because the log likelihoods live in llForward, the temporary
# workspace memory that the alpha kernel wrote into. Copying that data back into a PyTorch
# tensor is awkward (the Numba API is one-way), so we launch a pointwise CUDA kernel that
# updates costs in place from llForward, negates it and scales it by (1 + fastemit_lambda).
threadsperblock = min(costs.shape[0], 32)
blockspergrid = (costs.shape[0] + (threadsperblock - 1)) // threadsperblock
rnnt_helper.compute_costs_data[blockspergrid, threadsperblock, self.stream_, 0](
llForward, costs, self.fastemit_lambda_
)
self.stream_.synchronize()

# compute negative log likelihood.
for mb in range(self.minibatch_):
# Scale llForward by FastEmit lambda
costs[mb] = -costs[mb]
costs[mb] = (1.0 + self.fastemit_lambda_) * costs[mb]

return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS

def cost_and_grad(
@@ -271,3 +260,31 @@ def score_forward(
return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE

return self.compute_cost_and_score(acts, None, costs, pad_labels, label_lengths, input_lengths)

def _prepare_workspace(self) -> Tuple[int, Tuple[torch.Tensor, ...]]:
"""
Helper method that partitions the shared GPU workspace into the slices used by the computation.
Returns:
An int representing the number of workspace elements consumed (the offset into the used workspace).
A tuple of tensors (denom, alphas, betas, llForward, llBackward) that view the shared workspace.
"""
used_offset = 0

# // denom
denom = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
used_offset += self.maxT_ * self.maxU_ * self.minibatch_

# // alphas & betas
alphas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
used_offset += self.maxT_ * self.maxU_ * self.minibatch_
betas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
used_offset += self.maxT_ * self.maxU_ * self.minibatch_

# // logllh
llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
used_offset += self.minibatch_
llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
used_offset += self.minibatch_

return used_offset, (denom, alphas, betas, llForward, llBackward)
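As a rough sanity check of the slicing above (not part of the commit): _prepare_workspace hands out three maxT * maxU * minibatch element views (denom, alphas, betas) and two minibatch element views (llForward, llBackward), so the returned offset is 3 * maxT * maxU * minibatch + 2 * minibatch. A tiny sketch with made-up sizes:

# Hypothetical sizes, purely to illustrate how much of gpu_workspace the slices consume.
maxT, maxU, minibatch = 120, 25, 16

per_grid = maxT * maxU * minibatch   # elements for each of denom, alphas, betas
per_batch = minibatch                # elements for each of llForward, llBackward
used_offset = 3 * per_grid + 2 * per_batch

print(used_offset)  # 144032 for these sizes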
nemo/collections/asr/parts/numba/rnnt_loss/utils/rnnt_helper.py (18 changes: 18 additions & 0 deletions)
@@ -97,6 +97,24 @@ def log_plus(p1: float, p2: float):
return result


@cuda.jit(device=True, inline=True)
def copy_data_1d(source: torch.Tensor, dest: torch.Tensor, idx: int):
dest[idx] = source[idx]


@cuda.jit()
def compute_costs_data(source: torch.Tensor, dest: torch.Tensor, fastemit_lambda: float):
block = cuda.blockIdx.x
tid = cuda.threadIdx.x
idx = block * cuda.blockDim.x + tid
length = source.shape[0]

if idx < length:
copy_data_1d(source, dest, idx)
dest[idx] *= -1.0
dest[idx] *= 1.0 + fastemit_lambda


def get_workspace_size(
maxT: int, maxU: int, minibatch: int, gpu: bool
) -> (Optional[int], global_constants.RNNTStatus):
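One detail of the launch syntax used with this kernel: a Numba launch configuration can carry up to four entries, kernel[blocks_per_grid, threads_per_block, stream, dynamic_shared_mem_bytes], which is why compute_cost_and_score above passes self.stream_ and 0 after the grid and block sizes. Below is a standalone sketch of calling the new kernel this way; the sizes and the 0.001 lambda are made up, and it mirrors but is not taken from the NeMo call site.

import numpy as np
from numba import cuda

from nemo.collections.asr.parts.numba.rnnt_loss.utils import rnnt_helper

batch = 8
ll_forward = cuda.to_device(np.random.randn(batch))  # stands in for the llForward workspace slice
costs = cuda.to_device(np.zeros(batch))
stream = cuda.stream()

threads_per_block = min(batch, 32)
blocks_per_grid = (batch + threads_per_block - 1) // threads_per_block
# grid dim, block dim, stream, dynamic shared memory (unused here, hence 0)
rnnt_helper.compute_costs_data[blocks_per_grid, threads_per_block, stream, 0](ll_forward, costs, 0.001)
stream.synchronize()

print(costs.copy_to_host())  # equals -(1 + 0.001) * ll_forward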
tests/collections/asr/numba/rnnt_loss/utils/test_rnnt_helper.py (32 changes: 32 additions & 0 deletions)
@@ -314,6 +314,38 @@ def _kernel(x, y):
for i in range(len(x_new)):
assert x_new[i] == z[i]

@pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
@pytest.mark.parametrize('batch_size', [8, 128, 256])
@pytest.mark.parametrize('fastemit_lambda', [0.0, 0.001])
@pytest.mark.unit
def test_compute_costs_data(self, batch_size, fastemit_lambda):
numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

x = np.full([batch_size], fill_value=0.0) # np.random.rand(8192)
y = np.random.randn(batch_size) # np.random.rand(8192)

stream = cuda.stream()
x_c = cuda.to_device(x, stream=stream)
y_c = cuda.to_device(y, stream=stream)

# call kernel
threads_per_block = min(x.shape[0], 32)
blocks_per_grid = (x.shape[0] + (threads_per_block - 1)) // threads_per_block
# Kernel call (source, dest, extra_args_...)
rnnt_helper.compute_costs_data[blocks_per_grid, threads_per_block, stream](y_c, x_c, fastemit_lambda)

# sync kernel
stream.synchronize()

x_new = x_c.copy_to_host(stream=stream)
del x_c, y_c

res = -(y.copy())
res *= 1.0 + fastemit_lambda

for i in range(len(x_new)):
assert x_new[i] == res[i], f"index failed {i}"


if __name__ == '__main__':
pytest.main([__file__])
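A quick usage note on the new cases: they can be run in isolation with pytest, e.g. pytest tests/collections/asr/numba/rnnt_loss/utils/test_rnnt_helper.py -k test_compute_costs_data, or by executing the file directly via the pytest.main([__file__]) entry point above; on machines without CUDA the skipif marker skips them rather than failing.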
