Significantly speed up RNNT loss on CUDA #3653

Merged
merged 3 commits on Feb 15, 2022
Changes from all commits
7 changes: 2 additions & 5 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -191,9 +191,6 @@ def rnnt_loss_gpu(
    ### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ###
    acts, acts_shape = rnnt_helper.flatten_tensor(acts)

-    ### REPRESENT THE CUDA ARRAY INTERFACE OF COSTS VECTOR ###
-    costs_repr = cuda.as_cuda_array(costs, sync=False)  # NO COPY OF DATA, JUST CHANGE REPRESENTATION
-
    wrapper = gpu_rnnt.GPURNNT(
        minibatch=minibatch_size,
        maxT=maxT,
@@ -210,7 +207,7 @@
    if grads is None:
        status = wrapper.score_forward(
            acts=acts.data,
-            costs=costs_repr,
+            costs=costs.data,
            pad_labels=labels.data,
            label_lengths=label_lengths.data,
            input_lengths=input_lengths.data,
@@ -226,7 +223,7 @@
        status = wrapper.cost_and_grad(
            acts=acts.data,
            grads=grads.data,
-            costs=costs_repr,
+            costs=costs.data,
            pad_labels=labels.data,
            label_lengths=label_lengths.data,
            input_lengths=input_lengths.data,
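Context for the rnnt.py change: the explicit cuda.as_cuda_array(...) representation of the costs vector is no longer needed, because the costs tensor is now written by a Numba CUDA kernel directly (see the gpu_rnnt.py changes below); Numba kernels can consume torch CUDA tensors through the __cuda_array_interface__ protocol without copying. A minimal sketch of that interaction, assuming a CUDA-capable device; the negate_inplace kernel is purely illustrative and not part of this PR:

# Illustrative sketch only: a torch CUDA tensor can be handed straight to a Numba
# CUDA kernel because it exposes __cuda_array_interface__, so no cuda.as_cuda_array
# wrapper or host round-trip is required.
import torch
from numba import cuda

@cuda.jit()
def negate_inplace(x):
    idx = cuda.grid(1)  # global 1D thread index
    if idx < x.shape[0]:
        x[idx] = -x[idx]  # updates the torch tensor's device memory in place

costs = torch.rand(8, device='cuda')
threads_per_block = min(costs.shape[0], 32)
blocks_per_grid = (costs.shape[0] + threads_per_block - 1) // threads_per_block
negate_inplace[blocks_per_grid, threads_per_block](costs)
torch.cuda.synchronize()  # make the in-place update visible to the host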
nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
@@ -27,13 +27,13 @@
# limitations under the License.

import multiprocessing
-from typing import Optional
+from typing import Optional, Tuple

import numba
import torch
from numba import cuda

-from nemo.collections.asr.parts.numba.rnnt_loss.utils import global_constants
+from nemo.collections.asr.parts.numba.rnnt_loss.utils import global_constants, rnnt_helper
from nemo.collections.asr.parts.numba.rnnt_loss.utils.cuda_utils import gpu_rnnt_kernel, reduce

@@ -83,6 +83,7 @@ def __init__(

        if num_threads > 0:
            numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
+            self.num_threads_ = numba.get_num_threads()
        else:
            self.num_threads_ = numba.get_num_threads()

@@ -147,27 +148,12 @@ def compute_cost_and_score(
            An enum that either represents a successful RNNT operation or failure.
        """
        training = grads is not None
-        used_offset = 0
-
-        # // denom
-        denom = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
-        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
-
-        # // alphas & betas
-        alphas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
-        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
-        betas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
-        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
-
-        # // logllh
-        llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
-        used_offset += self.minibatch_
-        llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
-        used_offset += self.minibatch_
-
        if training:
            grads *= 0.0  # zero grads

+        used_offset, (denom, alphas, betas, llForward, llBackward) = self._prepare_workspace()
+
        ######## START EXECUTION ########
        self.log_softmax(acts, denom)

@@ -226,16 +212,19 @@ def compute_cost_and_score(
                self.clamp_,
            )

-        # // cost
-        costs.copy_to_device(llForward, stream=self.stream_)
+        # // cost copy, negate (for log likelihood) and update with additional regularizers
+        # This needs to be done via CUDA, because we used temporary memory llForward
+        # passed to alpha, which was updated with log likelihoods.
+        # But copying this data into a pytorch pointer is more difficult (numba api is one way)
+        # Therefore launch a pointwise CUDA kernel to update the costs inplace from data of llForward
+        # Then negate to compute the loglikelihood.
+        threadsperblock = min(costs.shape[0], 32)
+        blockspergrid = (costs.shape[0] + (threadsperblock - 1)) // threadsperblock
+        rnnt_helper.compute_costs_data[blockspergrid, threadsperblock, self.stream_, 0](
+            llForward, costs, self.fastemit_lambda_
+        )
        self.stream_.synchronize()

-        # compute negative log likelihood.
-        for mb in range(self.minibatch_):
-            # Scale llForward by FastEmit lambda
-            costs[mb] = -costs[mb]
-            costs[mb] = (1.0 + self.fastemit_lambda_) * costs[mb]
-
        return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS

    def cost_and_grad(
@@ -271,3 +260,31 @@ def score_forward(
            return global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE

        return self.compute_cost_and_score(acts, None, costs, pad_labels, label_lengths, input_lengths)
+
+    def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]):
+        """
+        Helper method that uses the workspace and constructs slices of it that can be used.
+
+        Returns:
+            An int, representing the offset of the used workspace (practically, the slice of the workspace consumed)
+            A tuple of tensors representing the shared workspace.
+        """
+        used_offset = 0
+
+        # // denom
+        denom = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
+        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
+
+        # // alphas & betas
+        alphas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
+        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
+        betas = self.gpu_workspace[used_offset : used_offset + self.maxT_ * self.maxU_ * self.minibatch_]
+        used_offset += self.maxT_ * self.maxU_ * self.minibatch_
+
+        # // logllh
+        llForward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
+        used_offset += self.minibatch_
+        llBackward = self.gpu_workspace[used_offset : used_offset + self.minibatch_]
+        used_offset += self.minibatch_
+
+        return used_offset, (denom, alphas, betas, llForward, llBackward)
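A note on the launch syntax used in compute_cost_and_score above: in Numba, the subscript on a kernel is its launch configuration, kernel[blocks_per_grid, threads_per_block, stream, dynamic_shared_memory_bytes](args). The following self-contained sketch mirrors that pattern; the scale kernel is hypothetical and only illustrates the four-element configuration with an explicit stream:

# Illustrative sketch only: four-element Numba launch configuration with an explicit
# stream and zero bytes of dynamic shared memory, as used for compute_costs_data.
import numpy as np
from numba import cuda

@cuda.jit()
def scale(x, factor):
    idx = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    if idx < x.shape[0]:
        x[idx] *= factor

stream = cuda.stream()
x = cuda.to_device(np.ones(100, dtype=np.float32), stream=stream)
threads_per_block = min(x.shape[0], 32)
blocks_per_grid = (x.shape[0] + (threads_per_block - 1)) // threads_per_block
scale[blocks_per_grid, threads_per_block, stream, 0](x, 2.0)
stream.synchronize()  # wait for the kernel before reading results back
result = x.copy_to_host(stream=stream)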
18 changes: 18 additions & 0 deletions nemo/collections/asr/parts/numba/rnnt_loss/utils/rnnt_helper.py
@@ -97,6 +97,24 @@ def log_plus(p1: float, p2: float):
    return result


+@cuda.jit(device=True, inline=True)
+def copy_data_1d(source: torch.Tensor, dest: torch.Tensor, idx: int):
+    dest[idx] = source[idx]
+
+
+@cuda.jit()
+def compute_costs_data(source: torch.Tensor, dest: torch.Tensor, fastemit_lambda: float):
+    block = cuda.blockIdx.x
+    tid = cuda.threadIdx.x
+    idx = block * cuda.blockDim.x + tid
+    length = source.shape[0]
+
+    if idx < length:
+        copy_data_1d(source, dest, idx)
+        dest[idx] *= -1.0
+        dest[idx] *= 1.0 + fastemit_lambda
+
+
def get_workspace_size(
    maxT: int, maxU: int, minibatch: int, gpu: bool
) -> (Optional[int], global_constants.RNNTStatus):
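The new compute_costs_data kernel is a pointwise, on-device replacement for the old host-side loop: for each batch element it writes dest[i] = -(1 + fastemit_lambda) * source[i], i.e. it copies the forward log-likelihood, negates it to obtain the loss, and applies the FastEmit scaling. A host-side NumPy reference of the same arithmetic (sketch, not part of the PR):

# Reference semantics only: what compute_costs_data computes per element on the GPU.
import numpy as np

def compute_costs_data_ref(source: np.ndarray, fastemit_lambda: float) -> np.ndarray:
    return -(1.0 + fastemit_lambda) * source

# e.g. a forward log-likelihood of -12.5 with fastemit_lambda = 0.001 becomes a cost of 12.5125
assert np.isclose(compute_costs_data_ref(np.array([-12.5]), 0.001)[0], 12.5125)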
32 changes: 32 additions & 0 deletions tests/collections/asr/numba/rnnt_loss/utils/test_rnnt_helper.py
@@ -314,6 +314,38 @@ def _kernel(x, y):
        for i in range(len(x_new)):
            assert x_new[i] == z[i]

+    @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
+    @pytest.mark.parametrize('batch_size', [8, 128, 256])
+    @pytest.mark.parametrize('fastemit_lambda', [0.0, 0.001])
+    @pytest.mark.unit
+    def test_compute_costs_data(self, batch_size, fastemit_lambda):
+        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
+
+        x = np.full([batch_size], fill_value=0.0)  # np.random.rand(8192)
+        y = np.random.randn(batch_size)  # np.random.rand(8192)
+
+        stream = cuda.stream()
+        x_c = cuda.to_device(x, stream=stream)
+        y_c = cuda.to_device(y, stream=stream)
+
+        # call kernel
+        threads_per_block = min(x.shape[0], 32)
+        blocks_per_grid = (x.shape[0] + (threads_per_block - 1)) // threads_per_block
+        # Kernel call (source, dest, extra_args_...)
+        rnnt_helper.compute_costs_data[blocks_per_grid, threads_per_block, stream](y_c, x_c, fastemit_lambda)
+
+        # sync kernel
+        stream.synchronize()
+
+        x_new = x_c.copy_to_host(stream=stream)
+        del x_c, y_c
+
+        res = -(y.copy())
+        res *= 1.0 + fastemit_lambda
+
+        for i in range(len(x_new)):
+            assert x_new[i] == res[i], f"index failed {i}"
+

if __name__ == '__main__':
    pytest.main([__file__])