Skip to content

Commit cdaab83

Browse files
ultronsbilgeacun
andcommitted
Intermediate experiments wav2vec
input shape temp update clean up dataset updates clean up move tensor idx to matrix op inside apply_mask use tensor operators to replace tensor indexing, passed consistency test verification Minor improvements Fix bucketpadlendataset Moved mask matrices creation to dataset prep. Remove dynamism, apply mask correctly, add some guardrails, some cleanups. Send device data to cpu before logging. Fix data bucketing for RawAudioDataset, refactor bucketing functions, fix filling w/ -inf in wav2vec2, minor cleanups Sample size computation during data prep to reduce atens, don't call item in log_scalar, minor cleanups Remove extra validation atens, clean up marking step and sending to cpu. Correct loss computation for w2v2 criterion + refactor index_put Fix bug in index_put + fix integer division Don't call float on extra logs, clean up comment. Correct accuracy computation, refactor xla tensor check. Adjust loss computation so it works w/ binary cross entropy. Remove sending log outputs back to cpu after allreduce. Don't sample padded states when sampling negatives + correct mi in loss computation. Fixing config issues after rebase Fix bug in negatives from everywhere Fixing config issue for TPU after rebase Taylan's changes on top of rebase Use float on cpu if fp16 when filling w/ -inf in w2v2 (#5) * Use float on cpu if fp16 when filling w/ -inf in w2v2 * xla -> self.xla * make logging_output_can_be_summed a regular method instead of staticmethod. Make tpu codepath work w/ hydra. (#6) * Make tpu codepath work w/ hydra. * Share and pass down model related args to rawaudiodataset correctly. * fp16 bug fix on non-xla devices (self._inftensor change) * use index_put to avoid dynamicity in model's fwd. * Get rid of some unnecessary warnings for tpus to clean up stderr. * Send logging outputs to cpu before logging to reduce atens. * Util function to move cpu tensors to tpu. * Use the util function to handle dummy batches to avoid crash at the end of epoch in distributed training. 
* fixing configs for precompute mask indices Co-authored-by: Bilge Acun <acun@fb.com>
1 parent 3aeb8fe commit cdaab83

14 files changed

+8977
-112
lines changed

fairseq/criterions/wav2vec_criterion.py

+45-20
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class Wav2VecCriterionConfig(FairseqDataclass):
3131
default_factory=lambda: [],
3232
metadata={"help": "output keys to log"},
3333
)
34-
34+
from fairseq.utils import index_put, is_xla_tensor
3535

3636
@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
3737
class Wav2vecCriterion(FairseqCriterion):
@@ -52,7 +52,9 @@ def forward(self, model, sample, reduce=True):
5252
net_output = model(**sample["net_input"])
5353
logits = model.get_logits(net_output).float()
5454
target = model.get_targets(sample, net_output)
55+
self.xla = is_xla_tensor(logits)
5556

57+
# XXX: handle weights on xla.
5658
weights = None
5759
if hasattr(model, "get_target_weights") and not self.infonce:
5860
weights = model.get_target_weights(target, net_output)
@@ -61,21 +63,31 @@ def forward(self, model, sample, reduce=True):
6163

6264
losses = []
6365

66+
reduction = "none" if ((not reduce) or self.xla) else "sum"
6467
if self.infonce:
65-
loss = F.cross_entropy(
66-
logits,
67-
target,
68-
reduction="sum" if reduce else "none",
69-
)
68+
loss = F.cross_entropy(logits, target, reduction=reduction)
7069
else:
7170
loss = F.binary_cross_entropy_with_logits(
72-
logits,
73-
target.float(),
74-
weights,
75-
reduction="sum" if reduce else "none",
71+
logits, target.float(), weights, reduction=reduction
72+
)
73+
74+
if self.xla:
75+
# tpu-comment: since dynamic shapes lead to recompilations on xla,
76+
# we don't shrink tensors using mask_indices.
77+
# Instead, we use mask indices to adjust loss.
78+
mi = (
79+
sample['net_input']['mask_indices']
80+
.transpose(0, 1) # logits are transposed in `model.get_logits`
81+
.reshape(logits.size(0))
7682
)
83+
loss = (loss * mi).sum() if reduce else (loss * mi)
7784

78-
sample_size = target.numel() if self.infonce else target.long().sum().item()
85+
if 'sample_size' in sample and self.infonce:
86+
sample_size = sample['sample_size']
87+
elif 'mask_indices' in sample['net_input']:
88+
sample_size = sample['net_input']['mask_indices'].sum()
89+
else:
90+
sample_size = target.numel() if self.infonce else target.long().sum().item()
7991
losses.append(loss.detach().clone())
8092

8193
if self.loss_weights is not None:
@@ -95,7 +107,7 @@ def forward(self, model, sample, reduce=True):
95107
losses.append(p)
96108

97109
logging_output = {
98-
"loss": loss.item() if reduce else loss,
110+
"loss": loss.item() if (reduce and not self.xla) else loss.detach(),
99111
"ntokens": sample_size,
100112
"nsentences": sample["id"].numel(),
101113
"sample_size": sample_size,
@@ -111,11 +123,14 @@ def forward(self, model, sample, reduce=True):
111123
if not self.training:
112124
logging_output["target"] = target.cpu().numpy()
113125
elif lk in net_output:
114-
logging_output[lk] = float(net_output[lk])
126+
value = net_output[lk]
127+
if not is_xla_tensor(value):
128+
value = float(value)
129+
logging_output[lk] = value
115130

116131
if len(losses) > 1:
117132
for i, l in enumerate(losses):
118-
logging_output[f"loss_{i}"] = l.item()
133+
logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()
119134

120135
if self.infonce:
121136
with torch.no_grad():
@@ -126,9 +141,15 @@ def forward(self, model, sample, reduce=True):
126141
assert logits.dim() > 1, logits.shape
127142
max = logits.argmax(-1) == 0
128143
min = logits.argmin(-1) == 0
129-
both = max & min
130-
corr = max.long().sum().item() - both.long().sum().item()
131-
count = max.numel()
144+
if is_xla_tensor(logits):
145+
max, min = max * mi, min * mi
146+
both = max & min
147+
corr = max.long().sum() - both.long().sum()
148+
count = mi.sum()
149+
else:
150+
both = max & min
151+
corr = max.long().sum().item() - both.long().sum().item()
152+
count = float(max.numel())
132153

133154
logging_output["correct"] = corr
134155
logging_output["count"] = count
@@ -188,11 +209,15 @@ def reduce_metrics(logging_outputs) -> None:
188209
else:
189210
metrics.log_scalar(k, val / len(logging_outputs), round=3)
190211

191-
@staticmethod
192-
def logging_outputs_can_be_summed() -> bool:
212+
# FIXME: revert when gather based xla reduction is implemented
213+
#@staticmethod
214+
#def logging_outputs_can_be_summed() -> bool:
215+
def logging_outputs_can_be_summed(self) -> bool:
    """
    Whether the logging outputs returned by `forward` can be summed
    across workers prior to calling `reduce_metrics`. Setting this
    to True will improve distributed training speed.

    Returns:
        The value of ``self.xla`` when it has been set, otherwise False.
    """
    # XXX: Gather based reduction not implemented for xla yet.
    # So we fall to sum based reduction for xla.
    # In the code visible here, ``self.xla`` is assigned inside `forward`;
    # default to False so that querying this before the first forward pass
    # does not raise AttributeError.
    return getattr(self, "xla", False)

fairseq/data/audio/raw_audio_dataset.py

+121-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import torch
1313
import torch.nn.functional as F
1414

15-
from .. import FairseqDataset
15+
from .. import FairseqDataset, BaseWrapperDataset
16+
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
1617

1718

1819
logger = logging.getLogger(__name__)
@@ -28,6 +29,8 @@ def __init__(
2829
min_length=0,
2930
pad=False,
3031
normalize=False,
32+
compute_mask_indices=False,
33+
**mask_compute_kwargs,
3134
):
3235
super().__init__()
3336

@@ -41,6 +44,14 @@ def __init__(
4144
self.pad = pad
4245
self.shuffle = shuffle
4346
self.normalize = normalize
47+
self.compute_mask_indices = compute_mask_indices
48+
if self.compute_mask_indices:
49+
self.mask_compute_kwargs = mask_compute_kwargs
50+
self._features_size_map = {}
51+
self._C = mask_compute_kwargs['encoder_embed_dim']
52+
self._conv_feature_layers = eval(
53+
mask_compute_kwargs['conv_feature_layers']
54+
)
4455

4556
def __getitem__(self, index):
4657
raise NotImplementedError()
@@ -72,6 +83,45 @@ def crop_to_max_size(self, wav, target_size):
7283
end = size - diff + start
7384
return wav[start:end]
7485

86+
def _compute_mask_indices(self, dims, padding_mask):
87+
B, T, C = dims
88+
mask_indices, mask_channel_indices = None, None
89+
if self.mask_compute_kwargs['mask_prob'] > 0:
90+
mask_indices = compute_mask_indices(
91+
(B, T),
92+
padding_mask,
93+
self.mask_compute_kwargs['mask_prob'],
94+
self.mask_compute_kwargs['mask_length'],
95+
self.mask_compute_kwargs['mask_selection'],
96+
self.mask_compute_kwargs['mask_other'],
97+
min_masks=2,
98+
no_overlap=self.mask_compute_kwargs['no_mask_overlap'],
99+
min_space=self.mask_compute_kwargs['mask_min_space'],
100+
)
101+
mask_indices = torch.from_numpy(mask_indices)
102+
if self.mask_compute_kwargs['mask_channel_prob'] > 0:
103+
mask_channel_indices = compute_mask_indices(
104+
(B, C),
105+
None,
106+
self.mask_compute_kwargs['mask_channel_prob'],
107+
self.mask_compute_kwargs['mask_channel_length'],
108+
self.mask_compute_kwargs['mask_channel_selection'],
109+
self.mask_compute_kwargs['mask_channel_other'],
110+
no_overlap=self.mask_compute_kwargs['no_mask_channel_overlap'],
111+
min_space=self.mask_compute_kwargs['mask_channel_min_space'],
112+
)
113+
mask_channel_indices = (
114+
torch.from_numpy(mask_channel_indices)
115+
.unsqueeze(1)
116+
.expand(-1, T, -1)
117+
)
118+
119+
return mask_indices, mask_channel_indices
120+
121+
@staticmethod
122+
def _bucket_tensor(tensor, num_pad, value):
123+
return F.pad(tensor, (0, num_pad), value=value)
124+
75125
def collater(self, samples):
76126
samples = [s for s in samples if s["source"] is not None]
77127
if len(samples) == 0:
@@ -103,9 +153,55 @@ def collater(self, samples):
103153
collated_sources[i] = self.crop_to_max_size(source, target_size)
104154

105155
input = {"source": collated_sources}
156+
out = {"id": torch.LongTensor([s["id"] for s in samples])}
106157
if self.pad:
107158
input["padding_mask"] = padding_mask
108-
return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}
159+
160+
if hasattr(self, 'num_buckets') and self.num_buckets > 0:
161+
assert self.pad, "Cannot bucket without padding first."
162+
bucket = max(self._bucketed_sizes[s['id']] for s in samples)
163+
num_pad = bucket - collated_sources.size(-1)
164+
if num_pad:
165+
input['source'] = self._bucket_tensor(
166+
collated_sources, num_pad, 0
167+
)
168+
input['padding_mask'] = self._bucket_tensor(
169+
padding_mask, num_pad, True
170+
)
171+
172+
if self.compute_mask_indices:
173+
B = input['source'].size(0)
174+
T = self._get_mask_indices_dims(input['source'].size(-1))
175+
padding_mask_reshaped = input['padding_mask'].clone()
176+
extra = padding_mask_reshaped.size(1) % T
177+
if extra > 0:
178+
padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
179+
padding_mask_reshaped = padding_mask_reshaped.view(
180+
padding_mask_reshaped.size(0), T, -1
181+
)
182+
padding_mask_reshaped = padding_mask_reshaped.all(-1)
183+
input['padding_count'] = (
184+
padding_mask_reshaped.sum(-1).max().item()
185+
)
186+
mask_indices, mask_channel_indices = self._compute_mask_indices(
187+
(B, T, self._C), padding_mask_reshaped,
188+
)
189+
input["mask_indices"] = mask_indices
190+
input["mask_channel_indices"] = mask_channel_indices
191+
out['sample_size'] = mask_indices.sum().item()
192+
193+
out["net_input"] = input
194+
return out
195+
196+
def _get_mask_indices_dims(self, size, padding=0, dilation=1):
197+
if size not in self._features_size_map:
198+
L_in = size
199+
for (_, kernel_size, stride) in self._conv_feature_layers:
200+
L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1
201+
L_out = 1 + L_out // stride
202+
L_in = L_out
203+
self._features_size_map[size] = L_out
204+
return self._features_size_map[size]
109205

110206
def num_tokens(self, index):
111207
return self.size(index)
@@ -141,6 +237,9 @@ def __init__(
141237
min_length=0,
142238
pad=False,
143239
normalize=False,
240+
num_buckets=0,
241+
compute_mask_indices=False,
242+
**mask_compute_kwargs,
144243
):
145244
super().__init__(
146245
sample_rate=sample_rate,
@@ -150,6 +249,8 @@ def __init__(
150249
min_length=min_length,
151250
pad=pad,
152251
normalize=normalize,
252+
compute_mask_indices=compute_mask_indices,
253+
**mask_compute_kwargs,
153254
)
154255

155256
self.fnames = []
@@ -168,8 +269,26 @@ def __init__(
168269
self.fnames.append(items[0])
169270
self.line_inds.add(i)
170271
self.sizes.append(sz)
272+
self.set_bucket_info(num_buckets)
171273
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
172274

275+
def set_bucket_info(self, num_buckets):
    """Record the bucket count and, if positive, derive bucket sizes.

    When ``num_buckets`` is positive this sets ``self._collated_sizes``
    (sample sizes capped at ``self.max_sample_size``), ``self.buckets`` and
    ``self._bucketed_sizes``, then logs the buckets. Otherwise only
    ``self.num_buckets`` is set.
    """
    self.num_buckets = num_buckets
    if not self.num_buckets > 0:
        return

    # Cap every sample size at the maximum sample size.
    capped = np.minimum(np.array(self.sizes), self.max_sample_size)
    self._collated_sizes = capped

    boundaries = get_buckets(self._collated_sizes, self.num_buckets)
    self.buckets = boundaries

    self._bucketed_sizes = get_bucketed_sizes(
        self._collated_sizes, self.buckets
    )
    logger.info(
        f"{len(self.buckets)} bucket(s) for the audio dataset: "
        f"{self.buckets}"
    )
291+
173292
def __getitem__(self, index):
174293
import soundfile as sf
175294

fairseq/data/bucket_pad_length_dataset.py

+24-22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import torch.nn.functional as F
88
from fairseq.data import BaseWrapperDataset
9+
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
910

1011

1112
class BucketPadLengthDataset(BaseWrapperDataset):
@@ -29,42 +30,43 @@ def __init__(
2930
num_buckets,
3031
pad_idx,
3132
left_pad,
33+
tensor_key=None,
3234
):
3335
super().__init__(dataset)
3436
self.pad_idx = pad_idx
3537
self.left_pad = left_pad
3638

3739
assert num_buckets > 0
38-
self.buckets = np.unique(
39-
np.percentile(
40-
sizes,
41-
np.linspace(0, 100, num_buckets + 1),
42-
interpolation="lower",
43-
)[1:]
44-
)
40+
self.buckets = get_buckets(sizes, num_buckets)
41+
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
42+
self._tensor_key = tensor_key
4543

46-
def get_bucketed_sizes(orig_sizes, buckets):
47-
sizes = np.copy(orig_sizes)
48-
assert np.min(sizes) >= 0
49-
start_val = -1
50-
for end_val in buckets:
51-
mask = (sizes > start_val) & (sizes <= end_val)
52-
sizes[mask] = end_val
53-
start_val = end_val
54-
return sizes
44+
def _set_tensor(self, item, val):
45+
if self._tensor_key is None:
46+
return val
47+
item[self._tensor_key] = val
48+
return item
5549

56-
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
50+
def _get_tensor(self, item):
51+
if self._tensor_key is None:
52+
return item
53+
return item[self._tensor_key]
5754

58-
def __getitem__(self, index):
59-
item = self.dataset[index]
60-
bucket_size = self._bucketed_sizes[index]
61-
num_pad = bucket_size - item.size(-1)
55+
def _pad(self, tensor, bucket_size, dim=-1):
56+
num_pad = bucket_size - tensor.size(dim)
6257
return F.pad(
63-
item,
58+
tensor,
6459
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
6560
value=self.pad_idx,
6661
)
6762

63+
def __getitem__(self, index):
64+
item = self.dataset[index]
65+
bucket_size = self._bucketed_sizes[index]
66+
tensor = self._get_tensor(item)
67+
padded = self._pad(tensor, bucket_size)
68+
return self._set_tensor(item, padded)
69+
6870
@property
6971
def sizes(self):
7072
return self._bucketed_sizes

0 commit comments

Comments
 (0)