Skip to content

Commit

Permalink
typo in get_feat_range_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
popcornell committed Jan 31, 2024
1 parent 5393298 commit 68f115b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nemo/collections/asr/models/msdd_v2_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ def get_feature_index_map(
feature_count_range = target_timestamps[:, 1] - target_timestamps[:, 0]

# Pre-compute feature indices for faster assigning:
feature_frame_length_range, feature_frame_interval_range= self.get_feat_range_matirx(max_feat_len=processed_signal.shape[2],
feature_frame_length_range, feature_frame_interval_range= self.get_feat_range_matrix(max_feat_len=processed_signal.shape[2],
feature_count_range=feature_count_range,
target_timestamps=target_timestamps,
device=processed_signal.device)
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def forward_multiscale(
vad_probs_steps = None
return embs, pools, vad_probs_steps, vad_prob_segments

def get_feat_range_matirx(self, max_feat_len, feature_count_range, target_timestamps, device):
def get_feat_range_matrix(self, max_feat_len, feature_count_range, target_timestamps, device):
"""
"""
feat_index_range = torch.arange(0, max_feat_len).to(device)
Expand All @@ -1068,7 +1068,7 @@ def reshape_vad_frames(self, vad_probs_frame, max_feat_len, ms_seg_timestamps, m
max_seq_len = torch.min(torch.tensor([ms_seg_timestamps.shape[2], max_seq_len]))
target_timestamps = ms_seg_timestamps[0, -1].to(torch.int64)
feature_count_range = target_timestamps[:, 1] - target_timestamps[:, 0]
_, ffir = self.get_feat_range_matirx(max_feat_len, feature_count_range, target_timestamps, device=ms_seg_timestamps.device)
_, ffir = self.get_feat_range_matrix(max_feat_len, feature_count_range, target_timestamps, device=ms_seg_timestamps.device)
vad_probs_steps = vad_probs_frame[:, ffir].reshape(vad_probs_frame.shape[0], max_seq_len, -1).mean(dim=2)
return vad_probs_steps

Expand Down Expand Up @@ -2187,4 +2187,4 @@ def list_available_models(cls) -> List[PretrainedModelInfo]:
Returns:
List of available pre-trained models.
"""
return EncDecDiarLabelModel.list_available_models()
return EncDecDiarLabelModel.list_available_models()

0 comments on commit 68f115b

Please sign in to comment.