
Fix for concat map dataset #5133

Merged · 10 commits · Nov 15, 2022
nemo/collections/common/data/dataset.py: 129 changes (63 additions, 66 deletions)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List
+from typing import Any, List, Optional, Tuple
 
 import numpy as np
 import torch.utils.data as pt_data
@@ -189,86 +189,83 @@ class ConcatMapDataset(Dataset):
         sampling_temperature (int): Temperature value for sampling. Only used when sampling_technique = 'temperature'.
             Defaults to 5.
         sampling_probabilities (list): Probability values for sampling. Only used when sampling_technique = 'random'.
+        seed: Optional value to seed the numpy RNG.
     """
 
     def __init__(
         self,
         datasets: List[Any],
         sampling_technique: str = 'temperature',
         sampling_temperature: int = 5,
-        sampling_probabilities: List[float] = None,
-        consumed_samples: int = 0,
+        sampling_probabilities: Optional[List[float]] = None,
+        seed: Optional[int] = None,
     ):
         super().__init__()
         self.datasets = datasets
-        self.sampling_kwargs = {}
-        self.size = 0
+        self.lengths = [len(x) for x in self.datasets]
         self.sampling_technique = sampling_technique
         self.sampling_temperature = sampling_temperature
         self.sampling_probabilities = sampling_probabilities
-        self.consumed_samples = consumed_samples
-        self.np_rng = np.random.RandomState(consumed_samples)
-
-        for dataset in datasets:
-            self.size += len(dataset)
-
-        # Pointer into the next index to fetch from each dataset
-        self.dataset_index = np.zeros(len(self.datasets), dtype=np.uint8)
-        self.permuted_dataset_indices = []
-        for dataset in self.datasets:
-            permuted_indices = np.arange(len(dataset))
-            self.np_rng.shuffle(permuted_indices)
-            self.permuted_dataset_indices.append(permuted_indices)
-
-        if self.sampling_technique == 'temperature':
-            lengths = []
-            for dataset in datasets:
-                lengths.append(len(dataset))
-
-            p = np.array(lengths) / np.sum(lengths)
-            p = np.power(p, 1 / self.sampling_temperature)
-            self.p = p
-
-        elif self.sampling_technique == 'random':
-            if not self.sampling_probabilities:
-                raise ValueError(
-                    "Random generator expects a 'sampling_probabilities' - a list of probability values corresponding to each dataset."
-                )
-
-            if len(self.sampling_probabilities) != len(self.datasets):
-                raise ValueError(
-                    f"Length of probabilities list must be equal to the number of datasets. Found {len(sampling_probabilities)} probs and {len(self.datasets)} datasets."
-                )
-
-            p = np.array(self.sampling_probabilities)
-            self.p = p / np.sum(p)  # Ensure probabilities sum to 1
+        self.np_rng = np.random.RandomState(seed)
+
+        # Build a list of size `len(self)`. Each tuple contains (dataset_id, dataset_index)
+        self.indices: List[Tuple[int, int]] = []
+        # Current position as we consume indices from each data set
+        dataset_positions = [0] * len(self.datasets)
+        # Random permutation of each dataset. Will be regenerated when exhausted.
+        shuffled_indices = [self.np_rng.permutation(len(x)) for x in self.datasets]
+        # Build the list of randomly-chosen datasets spanning the entire length, adhering to sampling technique
+        if self.sampling_technique == "round-robin":
+            # To exhaust longest dataset, need to draw `num_datasets * max_dataset_len` samples
+            total_length = max(self.lengths) * len(self.lengths)
+            # For round robin, iterate through each dataset
+            dataset_ids = np.arange(total_length) % len(self.datasets)
+            for dataset_id in dataset_ids:
+                position = dataset_positions[dataset_id]
+                index = shuffled_indices[dataset_id][position]
+                self.indices.append((dataset_id, index))
+                dataset_positions[dataset_id] += 1
+                if dataset_positions[dataset_id] == len(shuffled_indices[dataset_id]):
+                    dataset_positions[dataset_id] = 0
+                    shuffled_indices[dataset_id] = self.np_rng.permutation(len(self.datasets[dataset_id]))
+        else:
+            # Resolve probabilities of drawing from each data set
+            if self.sampling_technique == "random":
+                if sampling_probabilities is None or len(sampling_probabilities) != len(self.datasets):
+                    raise ValueError(
+                        f"Need {len(self.datasets)} probabilities; got "
+                        f"{len(sampling_probabilities) if sampling_probabilities is not None else 'None'}"
+                    )
+                p = np.array(self.sampling_probabilities)
+            elif self.sampling_technique == "temperature":
+                p = np.array([len(x) for x in self.datasets])
+                p = np.power(p, 1 / self.sampling_temperature)
+            else:
+                raise ValueError(f"Couldn't interpret sampling technique: {sampling_technique}")
+            # Normalize probabilities
+            p = p / np.sum(p)
+            # Will randomly choose from datasets
+            choices = np.arange(len(self.datasets))
+            # Keep going until largest dataset is exhausted.
+            exhausted_datasets = set()
+            while len(exhausted_datasets) < len(self.datasets):
+                # Randomly choose a dataset for each position in accordance with p
+                dataset_id = self.np_rng.choice(a=choices, p=p)
+                dataset = self.datasets[dataset_id]
+                # Pick next index from dataset
+                position = dataset_positions[dataset_id]
+                index = shuffled_indices[dataset_id][position]
+                self.indices.append((dataset_id, index))
+                # Maybe reset this dataset's permutation
+                dataset_positions[dataset_id] += 1
+                if dataset_positions[dataset_id] >= len(dataset):
+                    shuffled_indices[dataset_id] = self.np_rng.permutation(len(dataset))
+                    dataset_positions[dataset_id] = 0
+                    exhausted_datasets.add(dataset_id)
 
     def __len__(self):
-        return self.size
-
-    def _get_dataset_index(self, idx):
-        if self.sampling_technique == 'temperature' or self.sampling_technique == 'random':
-            return self.np_rng.choice(np.arange(len(self.datasets)), p=self.p)
-        elif self.sampling_technique == 'round-robin':
-            return idx % len(self.datasets)
+        return len(self.indices)
 
     def __getitem__(self, idx):
-        # Get the dataset we want to sample from
-        dataset_index = self._get_dataset_index(idx)
-
-        # Get the index of the sample we want to fetch from the dataset
-        sample_idx = self.dataset_index[dataset_index]
-
-        # If the sample idx > dataset size, reset to 0.
-        if sample_idx > len(self.datasets[dataset_index]):
-            sample_idx = 0
-            self.dataset_index[dataset_index] = 0
-
-        # Sample index -> shuffled sample index
-        shuffled_sample_idx = self.permuted_dataset_indices[dataset_index][sample_idx]
-
-        sample = self.datasets[dataset_index][shuffled_sample_idx]
-        self.dataset_index[dataset_index] += 1
-
-        return sample
+        dataset_id, dataset_index = self.indices[idx]
+        return self.datasets[dataset_id][dataset_index]
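
Editor's note, not part of the PR: the fix moves all randomness out of __getitem__. For the 'temperature' and 'random' techniques, the old class drew a dataset from the RNG on every access (ignoring idx) and tracked per-dataset positions in a np.uint8 array that can wrap past 255; the rewrite precomputes one (dataset_id, dataset_index) plan in __init__, so __getitem__ is a pure lookup and, given a fixed seed, fully reproducible across workers and epochs. Below is a minimal usage sketch, not from the PR, assuming plain Python lists as stand-in datasets (they satisfy the len()-and-indexing contract the class relies on); the names small and large are illustrative only.

# Minimal sketch (editor's illustration, not from the PR). Plain lists work as
# datasets here because ConcatMapDataset only calls len() and integer indexing.
from nemo.collections.common.data.dataset import ConcatMapDataset

small = [f"small-{i}" for i in range(4)]
large = [f"large-{i}" for i in range(16)]

# Temperature sampling draws dataset i with p_i proportional to n_i ** (1 / T),
# then normalizes. With T = 5 and lengths (4, 16): p ~ (4**0.2, 16**0.2)
# ~ (1.32, 1.74), i.e. p ~ (0.43, 0.57), so the small dataset is oversampled
# far beyond its 4/20 share of the raw data.
ds = ConcatMapDataset(
    datasets=[small, large],
    sampling_technique="temperature",
    sampling_temperature=5,
    seed=42,  # fixed seed -> the precomputed index plan is identical every run
)

# The plan is extended until every dataset has been exhausted at least once, so
# len(ds) varies with the draw but is at least len(small) + len(large) = 20.
print(len(ds))
print([ds[i] for i in range(5)])  # pure lookups into the precomputed plan

With sampling_technique="round-robin" the same sketch would instead yield a fixed length of max(len(small), len(large)) * 2 = 32, since that branch draws exactly num_datasets * max_dataset_len samples.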