Time-chunked datasets. #48

Merged · 1 commit · Feb 21, 2022
145 changes: 145 additions & 0 deletions clrs/_src/dataset.py
@@ -16,10 +16,15 @@

import dataclasses

import functools
from typing import Iterator

from clrs._src import probing
from clrs._src import samplers
from clrs._src import specs

import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

@@ -154,3 +159,143 @@ def create_dataset(folder, algorithm, split, batch_size):
  dataset = dataset.batch(batch_size)
  return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
          specs.SPECS[algorithm])


def _copy_hint(source, dest, i, start_source, start_dest, to_add):
  """Copy from full-sample hint to a hint chunk."""
  assert np.all(dest[start_dest:, i:] == 0)
  assert start_dest < dest.shape[0]
  assert start_dest + to_add <= dest.shape[0]
  assert start_source < source.shape[0]
  assert start_source + to_add <= source.shape[0]
  dest[start_dest:start_dest+to_add, i] = source[
      start_source:start_source+to_add, i]
  return dest


def _copy_io(source, dest, i, start_dest, to_add):
  """Copy from an input or output to an input or output chunk."""
  assert np.all(dest[start_dest:, i:] == 0)
  dest[start_dest:start_dest+to_add, i] = source[i]
  return dest
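
As a toy illustration of the two helpers above (the arrays here are hypothetical stand-ins for the real probe data, not part of the diff): hints keep their own leading time axis and are sliced along it, while inputs and outputs have no time axis and are simply repeated into the chunk.

  import numpy as np

  hint = np.arange(10).reshape(5, 2)              # time x batch
  hint_chunk = np.zeros((4, 2), dtype=hint.dtype)
  # Copies timesteps 1..3 of batch element 0 into chunk rows 0..2.
  _copy_hint(hint, hint_chunk, i=0, start_source=1, start_dest=0, to_add=3)

  inp = np.array([7, 9])                          # batch only, no time axis
  inp_chunk = np.zeros((4, 2), dtype=inp.dtype)
  # Repeats the input of batch element 1 along chunk rows 0..2.
  _copy_io(inp, inp_chunk, i=1, start_dest=0, to_add=3)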


def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
  """Generator of fixed-length chunks from full-trajectory samples.

  Args:
    dataset: full-sample dataset as numpy iterator.
    chunk_length: time length of chunks.
  Yields:
    Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
    has dimensions chunk_length x batch_size x ... Samples are not time-padded:
    after the end of one sample, the next one begins immediately. Since
    different samples can have different time lengths, the beginnings and ends
    of samples within a batch do not need to coincide. For this reason, the
    chunked dataset features include two chunk_length x batch_size int tensors,
    `is_first` and `is_last`, that mark the beginning and end of each sample.
    For example, if `chunk_length` == 6 and `batch_size` == 2 and the first
    full-sample batch had one sample of length 3 and one of length 5,
    we would have a first chunked batch with the following `is_first` and
    `is_last` tensors:

    is_first = [[1, 1]    is_last = [[0, 0]    ( sample id [[0 1]
                [0, 0]               [0, 0]                 [0 1]
                [0, 0]               [1, 0]                 [0 1]
                [1, 0]               [0, 0]                 [2 1]
                [0, 0]               [0, 1]                 [2 1]
                [0, 1]]              [0, 0]]                [2 3]] )

    while the data in the inputs, outputs and hints tensors would correspond
    to the samples identified by the `sample id` column above, shown for
    reference. Notice that, while in the full-sample dataset inputs and outputs
    have no time dimension, here they do; the input and output tensors are
    simply repeated along each sample's time length.
  """
  def _get_batch():
    d = next(dataset)
    return (d.features.inputs, d.features.hints, d.outputs,
            d.features.lengths.astype(int))

  inputs, hints, outputs, lengths = _get_batch()
  for inp in inputs:
    if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
      batch_size = inp.data.shape[0]
      break

  io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
  chunk_inputs = jax.tree_map(io_chunk, inputs)
  chunk_outputs = jax.tree_map(io_chunk, outputs)

  hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
  chunk_hints = jax.tree_map(hint_chunk, hints)

  # A FIFO of pending full-sample batches; `left` tracks how many timesteps
  # of each batch element have not yet been copied into a chunk.
  inputs = [inputs]
  hints = [hints]
  outputs = [outputs]
  left = [lengths.copy()]
  lengths = [lengths.copy()]

  while True:
    # Create a new empty chunk
    chunk_inputs = jax.tree_map(np.zeros_like, chunk_inputs)
    chunk_hints = jax.tree_map(np.zeros_like, chunk_hints)
    chunk_outputs = jax.tree_map(np.zeros_like, chunk_outputs)
    start_mark = np.zeros((chunk_length, batch_size), dtype=int)
    end_mark = np.zeros((chunk_length, batch_size), dtype=int)

    # Get enough data batches to fill the new chunk
    while np.any(np.sum(left, axis=0) < chunk_length):
      inp, hh, out, ll = _get_batch()
      inputs.append(inp)
      hints.append(hh)
      outputs.append(out)
      left.append(ll.copy())
      lengths.append(ll.copy())

    # Fill the chunk, one batch element at a time
    for i in range(batch_size):
      total, idx = 0, 0
      while total < chunk_length:
        to_add = min(left[idx][i], chunk_length - total)
        if to_add:
          start = lengths[idx][i] - left[idx][i]
          assert start >= 0
          f_io = functools.partial(_copy_io, i=i, start_dest=total,
                                   to_add=to_add)
          chunk_inputs = jax.tree_map(f_io, inputs[idx], chunk_inputs)
          chunk_outputs = jax.tree_map(f_io, outputs[idx], chunk_outputs)
          f_hint = functools.partial(_copy_hint, i=i, start_source=start,
                                     start_dest=total, to_add=to_add)
          chunk_hints = jax.tree_map(f_hint, hints[idx], chunk_hints)
          if start == 0:
            start_mark[total, i] = 1
          total += to_add
          left[idx][i] -= to_add
          assert left[idx][i] >= 0
          if left[idx][i] == 0:
            end_mark[total - 1, i] = 1
        idx += 1
      assert total == chunk_length

    # Drop batches that have been fully consumed by previous chunks.
    while left and np.all(left[0] == 0):
      inputs.pop(0)
      hints.pop(0)
      outputs.pop(0)
      left.pop(0)
      lengths.pop(0)

    yield samplers.Feedback(
        samplers.FeaturesChunked(chunk_inputs, chunk_hints,
                                 start_mark, end_mark),
        chunk_outputs)


def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
  dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
                      data_dir=folder, split=split)
  dataset = dataset.repeat()
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
  dataset = dataset.as_numpy_iterator()
  return chunkify(dataset, chunk_length), specs.SPECS[algorithm]
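
A minimal usage sketch (seed and sizes are hypothetical): `chunkify` accepts any iterator of full-sample Feedback batches, not only the tfds pipeline above, for example an in-memory sampler like the one used by the tests below.

  from clrs._src import dataset
  from clrs._src import samplers

  def _iterable_sampler(algo, batch_size):
    sampler, _ = samplers.build_sampler(
        algo, seed=0, num_samples=100, length=16)
    while True:
      yield sampler.next(batch_size)

  chunked = dataset.chunkify(_iterable_sampler('bfs', 4), chunk_length=20)
  chunk = next(chunked)
  # All chunk tensors carry a leading time axis of size chunk_length.
  assert chunk.features.is_first.shape == (20, 4)
  assert chunk.features.is_last.shape == (20, 4)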
116 changes: 116 additions & 0 deletions clrs/_src/dataset_test.py
@@ -0,0 +1,116 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Unit tests for `dataset.py`."""

from typing import Generator, List

from absl.testing import absltest
from absl.testing import parameterized

from clrs._src import dataset
from clrs._src import samplers
from clrs._src import specs
import numpy as np

_Array = np.ndarray


def _stack_to_shortest(x: List[_Array]) -> _Array:
  min_len = min(map(len, x))
  return np.array([a[:min_len] for a in x])


def _make_sampler(algo: str) -> samplers.Sampler:
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=samplers.CLRS30['val']['length'],
  )
  return sampler


def _make_iterable_sampler(
    algo: str, batch_size: int) -> Generator[samplers.Feedback, None, None]:
  sampler = _make_sampler(algo)
  while True:
    yield sampler.next(batch_size)


class DatasetTest(parameterized.TestCase):

  @parameterized.product(
      name=specs.CLRS_30_ALGS[:5],
      chunk_length=[20, 50])
  def test_chunkify(self, name: str, chunk_length: int):
    """Test that samples are concatenated and split in chunks correctly."""
    batch_size = 8

    ds = _make_iterable_sampler(name, batch_size)
    chunked_ds = dataset.chunkify(
        _make_iterable_sampler(name, batch_size),
        chunk_length)

    samples = [next(ds) for _ in range(20)]
    cum_lengths = np.cumsum([s.features.lengths for s in samples], axis=0)
    n_chunks = np.amax(cum_lengths[-1]).astype(int) // chunk_length + 1
    chunks = [next(chunked_ds) for _ in range(n_chunks)]

    # Check correctness of `is_first` and `is_last` markers
    start_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
        [c.features.is_first for c in chunks]).T]).T
    end_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
        [c.features.is_last for c in chunks]).T]).T
    assert len(start_idx) >= len(cum_lengths)
    start_idx = start_idx[:len(cum_lengths)]
    assert len(end_idx) >= len(cum_lengths)
    end_idx = end_idx[:len(cum_lengths)]

    np.testing.assert_equal(start_idx[0], 0)
    np.testing.assert_array_equal(cum_lengths - 1, end_idx)
    np.testing.assert_array_equal(cum_lengths[:-1], start_idx[1:])

    # Check that inputs, outputs and hints have been copied correctly
    all_input = np.concatenate([c.features.inputs[0].data for c in chunks])
    all_output = np.concatenate([c.outputs[0].data for c in chunks])
    all_hint = np.concatenate([c.features.hints[0].data for c in chunks])
    for i in range(batch_size):
      length0 = int(samples[0].features.lengths[i])
      length1 = int(samples[1].features.lengths[i])
      # Check first sample
      np.testing.assert_array_equal(
          all_input[:length0, i],
          np.tile(samples[0].features.inputs[0].data[i], [length0, 1]))
      np.testing.assert_array_equal(
          all_output[:length0, i],
          np.tile(samples[0].outputs[0].data[i], [length0, 1]))
      np.testing.assert_array_equal(
          all_hint[:length0, i],
          samples[0].features.hints[0].data[:length0, i])
      # Check second sample
      np.testing.assert_array_equal(
          all_input[length0:length0 + length1, i],
          np.tile(samples[1].features.inputs[0].data[i], [length1, 1]))
      np.testing.assert_array_equal(
          all_output[length0:length0 + length1, i],
          np.tile(samples[1].outputs[0].data[i], [length1, 1]))
      np.testing.assert_array_equal(
          all_hint[length0:length0 + length1, i],
          samples[1].features.hints[0].data[:length1, i])


if __name__ == '__main__':
  absltest.main()
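
A small follow-on sketch (assuming `chunks` collected as in the test above): within each batch lane, the ordinal of the sample a timestep belongs to can be recovered by counting `is_first` markers, which is essentially what the marker checks in the test rely on.

  is_first = np.concatenate(
      [c.features.is_first for c in chunks])  # time x batch
  # Ordinal of the current sample within each batch lane, per timestep;
  # timesteps sharing an ordinal belong to the same sample in that lane.
  lane_sample_idx = np.cumsum(is_first, axis=0) - 1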
2 changes: 2 additions & 0 deletions clrs/_src/samplers.py
@@ -35,6 +35,8 @@

Algorithm = Callable[..., Any]
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
FeaturesChunked = collections.namedtuple(
    'Features', ['inputs', 'hints', 'is_first', 'is_last'])
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])

# CLRS-30 baseline spec.
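
For context, a hedged sketch of how a consumer might use the new `is_first` field; the reset-on-sample-start convention is an assumption of this sketch, not something established by this PR.

  import numpy as np

  def step_through_chunk(features, hidden):
    """Toy recurrent walk over one chunk; `hidden` is batch x width state."""
    for t in range(features.is_first.shape[0]):
      # Zero the state of any batch lane where a new sample starts here.
      hidden = np.where(features.is_first[t][:, None] == 1, 0.0, hidden)
      # ... run one model step on time slice t of inputs/hints ...
    return hidden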