Skip to content

Commit

Permalink
forcing conv subsampling to 32 bit (NVIDIA#4293)
Browse files Browse the repository at this point in the history
* forcing conv subsampling to 32 bit

* extra character

* comments added

* moving autocast into subsampling

* lint

Co-authored-by: Vahid Noroozi <VahidooX@users.noreply.github.com>
Signed-off-by: Georg Kucsko <gkucsko@gmail.com>
  • Loading branch information
2 people authored and gkucsko committed Jun 2, 2022
1 parent 3d72e52 commit bb727b3
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion nemo/collections/asr/parts/submodules/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,14 @@ def forward(self, x, lengths):
repeat_num=self._sampling_num,
)
x = x.unsqueeze(1)
x = self.conv(x)
if self._subsampling == 'striding':
# added in order to prevent slowdown in torch.nn.Conv2d with bfloat16 / CUDNN v8 API
# to be removed once the above is fixed in cudnn
with torch.cuda.amp.autocast(dtype=torch.float32):
x = self.conv(x)
else:
x = self.conv(x)

b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).reshape(b, t, -1))
return x, lengths
Expand Down

0 comments on commit bb727b3

Please sign in to comment.