Make dataset reader faster #2439

Merged
11 commits merged on May 8, 2024
ci_test/common_python/tools.py (24 changes: 15 additions & 9 deletions)
@@ -697,6 +697,7 @@ def create_tests(setup_func,
test_file,
test_name_base=None,
skip_clusters=[],
post_test_func=None,
**kwargs):
"""Create functions that can interact with PyTest

@@ -724,6 +725,8 @@ def create_tests(setup_func,
cases, use `__file__`.
test_name (str, optional): Descriptive name (default: test
file name with '.py' removed).
post_test_func (function): Runs after the LBANN experiment if
successful.
**kwargs: Keyword arguments to pass into
`lbann.contrib.launcher.run`.

@@ -765,7 +768,8 @@

environment['MIOPEN_USER_DB_PATH'] = f'{db_path}/MIOpen_user_db'
# Empirically the cache dir cannot be on a parallel file system
environment['MIOPEN_CUSTOM_CACHE_DIR'] =f'{tmpdir}/MIOpen_custom_cache'
environment[
'MIOPEN_CUSTOM_CACHE_DIR'] = f'{tmpdir}/MIOpen_custom_cache'

kwargs['environment'] = environment

@@ -777,7 +781,8 @@ def test_func(cluster, dirname, weekly):
"""
test_name = '{}'.format(test_name_base)
if cluster in skip_clusters:
e = "test \"%s\" not supported on cluster \"%s\"" % (test_name, cluster)
e = "test \"%s\" not supported on cluster \"%s\"" % (test_name,
cluster)
print('Skip - ' + e)
pytest.skip(e)

@@ -786,7 +791,8 @@ def test_func(cluster, dirname, weekly):
import lbann.contrib.launcher

# Setup LBANN experiment
trainer, model, data_reader, optimizer, req_num_nodes = setup_func(lbann, weekly)
trainer, model, data_reader, optimizer, req_num_nodes = setup_func(
lbann, weekly)

if req_num_nodes:
kwargs['nodes'] = req_num_nodes
@@ -795,12 +801,12 @@
_kwargs = copy.deepcopy(kwargs)
if 'work_dir' not in _kwargs:
_kwargs['work_dir'] = os.path.join(os.path.dirname(test_file),
'experiments',
test_name)
'experiments', test_name)

# If the user provided a suffix for the work directory, append it
if 'work_subdir' in _kwargs:
_kwargs['work_dir'] = os.path.join(_kwargs['work_dir'], _kwargs['work_subdir'])
_kwargs['work_dir'] = os.path.join(_kwargs['work_dir'],
_kwargs['work_subdir'])
del _kwargs['work_subdir']

# Delete the work directory
@@ -828,6 +834,8 @@ def test_func(cluster, dirname, weekly):
**_kwargs,
)
assert_success(return_code, stderr_log_file)
if post_test_func is not None:
post_test_func(lbann, weekly)
return {
'return_code': return_code,
'work_dir': work_dir,
@@ -838,9 +846,7 @@
# Specific test functions name
test_func.__name__ = test_name_base

return (
test_func,
)
return (test_func, )


def create_python_data_reader(lbann,
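The new post_test_func hook lets a test run additional checks after the LBANN experiment has completed successfully; it receives the same (lbann, weekly) arguments as the setup function. A minimal sketch of how a CI test module might use it — setup_experiment and check_outputs are hypothetical names, and the sys.path manipulation real tests use to import tools is omitted:

import tools  # ci_test/common_python/tools.py (import-path setup omitted)

def setup_experiment(lbann, weekly):
    # Hypothetical setup: construct and return the objects create_tests expects.
    trainer = model = data_reader = optimizer = None  # build real LBANN objects here
    req_num_nodes = None  # no specific node-count requirement
    return trainer, model, data_reader, optimizer, req_num_nodes

def check_outputs(lbann, weekly):
    # Hypothetical post-test hook: runs only after assert_success() passes.
    pass

# create_tests returns a tuple of PyTest-discoverable test functions.
for _test_func in tools.create_tests(setup_experiment,
                                     __file__,
                                     post_test_func=check_outputs):
    globals()[_test_func.__name__] = _test_func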
@@ -87,13 +87,22 @@ class python_dataset_reader : public generic_data_reader

private:
void queue_epoch();
void queue_samples(uint64_t samples_to_queue);

/** @brief Path to the pickled dataset object. */
std::string m_dataset_path;
/** @brief Optional directory containing module with dataset definition. */
std::string m_module_dir;
/** @brief Number of samples to prefetch per worker. */
uint64_t m_prefetch_factor;
int m_prefetch_factor;
/** @brief Number of I/O threads. */
int m_num_io_threads;
/** @brief The current dataset shuffled minibatch offset. */
uint64_t m_dataset_minibatch_offset;
/** @brief The current dataset shuffled sample offset. */
uint64_t m_dataset_sample_offset;
/** @brief Number of samples requested this epoch. */
uint64_t m_queued_samples;
/** @brief Dimensions of data sample tensor. */
std::vector<El::Int> m_sample_dims;
/** @brief Size of label tensor. */
Expand Down
python/lbann/util/data.py (67 changes: 29 additions & 38 deletions)
@@ -161,9 +161,8 @@ class DataReader:
Helper class used by LBANN to control worker processes and handle sample/batch loading.
"""

def __init__(
self, dataset: Dataset, num_procs: int, prefetch_factor: int, dtype: str
) -> None:
def __init__(self, dataset: Dataset, num_procs: int, prefetch_factor: int,
dtype: str) -> None:
"""
DataReader Constructor

@@ -184,16 +183,17 @@ def __init__(
self.dtype = dtype
self.sample_dims = dataset.get_sample_dims()
self.num_io_partitions = 1
self.queued_indices = []
self.loaded_samples = []

if isinstance(self.dataset, DistConvDataset):
self.num_io_partitions = self.dataset.num_io_partitions

self.pool = Pool(processes=num_procs, initializer=DataReader.init_worker)
self.pool = Pool(processes=num_procs,
initializer=DataReader.init_worker,
initargs=(self.dataset, ))

@staticmethod
def init_worker():
def init_worker(dataset):
"""
Initialize worker process.

@@ -209,16 +209,21 @@ def init_worker():
except:
pass

# Process-local storage
global g_dataset
g_dataset = dataset

def terminate(self) -> None:
"""
Terminate all worker processes.
"""
self.pool.terminate()

@staticmethod
def load_sample(dataset, ind) -> None:
def load_sample(ind) -> Sample:
"""
Loads the sample from the dataset at the specified index.
This function must be called from a worker process.

:param dataset: Dataset
:type dataset: Dataset
@@ -227,32 +232,25 @@ def load_sample(dataset, ind) -> None:
:return: Sample
:rtype: Sample
"""
return dataset[ind]
return g_dataset[ind]

def load_next_sample_async(self):
def load_next_sample_async(self, ind: int):
"""
Submit the next sample index to be loaded to the worker pool.
"""
self.loaded_samples.append(
self.pool.apply_async(
DataReader.load_sample, (self.dataset, self.queued_indices.pop(0))
)
)
self.pool.apply_async(DataReader.load_sample, (ind, )))
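
The important change in this class is how the dataset reaches the worker processes: rather than pickling self.dataset into every apply_async call, the dataset is shipped once through the pool's initargs and cached in a process-local global (g_dataset), so each task afterwards only carries an integer index. A self-contained sketch of this pattern outside of LBANN (names are illustrative):

from multiprocessing import Pool

_worker_dataset = None  # per-process storage, filled in once per worker

def _init_worker(dataset):
    # Runs once in each worker when the pool starts; caches the dataset locally.
    global _worker_dataset
    _worker_dataset = dataset

def _load_sample(ind):
    # Only the integer index crosses the process boundary per task.
    return _worker_dataset[ind]

if __name__ == '__main__':
    dataset = list(range(1000))  # stand-in for a Dataset object
    pool = Pool(processes=4, initializer=_init_worker, initargs=(dataset, ))
    results = [pool.apply_async(_load_sample, (i, )) for i in range(8)]
    print([r.get() for r in results])  # [0, 1, 2, 3, 4, 5, 6, 7]
    pool.terminate()

For large dataset objects this avoids re-serializing the dataset for every sample request.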

def queue_epoch(self, inds: List[int]) -> None:
def queue_samples(self, inds: List[int]) -> None:
"""
Set the indices to be loaded this epoch and start submitting jobs
to the worker pool.

:param inds: List of sample indices
:type inds: List[int]
"""
self.queued_indices += inds
while (
len(self.loaded_samples) < self.num_procs * self.prefetch_factor
and len(self.queued_indices) > 0
):
self.load_next_sample_async()
for ind in inds:
self.load_next_sample_async(ind)

def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
"""
@@ -266,38 +264,31 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
samples = []
for _ in range(batch_size):
samples.append(self.loaded_samples.pop(0).get())
if len(self.queued_indices) > 0:
self.load_next_sample_async()

batch = {}

# Note: we return the arrays with the pointers so that they aren't
# deallocated by the garbage collector.
batch["sample"] = np.ascontiguousarray(
[s.sample for s in samples], dtype=self.dtype
)
batch["sample"] = np.ascontiguousarray([s.sample for s in samples],
dtype=self.dtype)
batch["sample_ptr"] = batch["sample"].ctypes.data
assert (
batch["sample"].size
== np.prod(self.sample_dims.sample) * batch_size / self.num_io_partitions
)
assert (batch["sample"].size == np.prod(self.sample_dims.sample) *
batch_size / self.num_io_partitions)

if hasattr(self.sample_dims, "label"):
batch["label"] = np.ascontiguousarray(
[s.label for s in samples], dtype=self.dtype
)
batch["label"] = np.ascontiguousarray([s.label for s in samples],
dtype=self.dtype)
batch["label_ptr"] = batch["label"].ctypes.data
assert batch["label"].size == np.prod(self.sample_dims.label) * batch_size
assert batch["label"].size == np.prod(
self.sample_dims.label) * batch_size

if hasattr(self.sample_dims, "response"):
batch["response"] = np.ascontiguousarray(
[s.response for s in samples], dtype=self.dtype
)
[s.response for s in samples], dtype=self.dtype)
batch["response_ptr"] = batch["response"].ctypes.data
assert (
batch["response"].size
== np.prod(self.sample_dims.response) * batch_size
)
batch["response"].size == np.prod(self.sample_dims.response) *
batch_size)

return batch
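
With these changes queue_samples submits every index it is given right away and get_batch simply drains the results in order; deciding how many indices to hand over at a time is now the caller's job (the C++ python_dataset_reader). A rough usage sketch with a toy dataset — illustrative only, and it assumes Dataset, Sample, and SampleDims from lbann.util.data behave as their usage in this file suggests:

import numpy as np
from lbann.util.data import DataReader, Dataset, Sample, SampleDims

class RandomDataset(Dataset):
    # Hypothetical toy dataset of 16-element float vectors.
    def __len__(self):
        return 1024

    def __getitem__(self, ind):
        return Sample(sample=np.random.rand(16).astype(np.float32))

    def get_sample_dims(self):
        return SampleDims(sample=[16])

if __name__ == '__main__':
    reader = DataReader(RandomDataset(), num_procs=4, prefetch_factor=2,
                        dtype='float32')
    reader.queue_samples(list(range(64)))    # hand over a first chunk of indices
    batch = reader.get_batch(batch_size=32)  # blocks until 32 samples are loaded
    print(batch['sample'].shape)             # (32, 16)
    reader.terminate()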

src/data_ingestion/readers/data_reader_python_dataset.cpp (66 changes: 53 additions & 13 deletions)
@@ -92,6 +92,11 @@ bool python_dataset_reader::fetch_data_block(
// Acquire Python GIL on first IO thread
python::global_interpreter_lock gil;

// If not enough samples were queued when the epoch began, queue the rest
if (m_queued_samples < mb_size) {
queue_samples(mb_size - m_queued_samples);
}

// Check that shared memory array is large enough
uint64_t num_io_partitions = 1;
#ifdef LBANN_HAS_DISTCONV
@@ -155,6 +160,9 @@
El::Copy(response_shared_memory_matrix, Y);
}

// Prefetch the next minibatch asynchronously
this->queue_samples(mb_size);

return true;
}
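
Taken together, the two new queue_samples calls turn fetch_data_block into a simple producer/consumer pipeline: before copying a minibatch out of shared memory it tops up any shortfall left from the initial prefetch, and after the copy it asynchronously queues the indices for the next minibatch so the Python workers stay busy while LBANN computes on the current one. A small Python model of that loop — illustrative pseudocode, not the LBANN API:

def fetch_data_block(reader, mb_size):
    # queued_samples counts everything requested so far this epoch, so this
    # top-up only triggers on the first minibatch, when the warm-up prefetch
    # (prefetch_factor * num_io_threads) was smaller than mb_size.
    if reader.queued_samples < mb_size:
        reader.queue_samples(mb_size - reader.queued_samples)

    batch = reader.get_batch(mb_size)  # blocks only on samples not yet loaded

    # Keep roughly one minibatch in flight for the next call.
    reader.queue_samples(mb_size)
    return batch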

@@ -247,6 +255,7 @@ void python_dataset_reader::setup(int num_io_threads,
observer_ptr<thread_pool> io_thread_pool)
{
generic_data_reader::setup(num_io_threads, io_thread_pool);
m_num_io_threads = num_io_threads;

// Acquire Python GIL
python::global_interpreter_lock gil;
@@ -275,34 +284,65 @@
queue_epoch();
}

void python_dataset_reader::queue_epoch()
void python_dataset_reader::queue_samples(uint64_t samples_to_queue)
{
// Acquire Python GIL
python::global_interpreter_lock gil;

// NOTE: ASSUMES GIL IS ALREADY TAKEN
execution_mode mode = exec_mode_from_string(get_role());
dataset& ds = get_trainer().get_data_coordinator().get_dataset(mode);

// Get shuffled indices to be fetched by worker processes
python::object inds_list = PyList_New(0);
uint64_t num_samples = m_num_samples;
uint64_t base_offset = ds.get_base_offset();
uint64_t sample_stride = ds.get_sample_stride();
uint64_t mini_batch_stride = ds.get_stride_to_next_mini_batch();
for (uint64_t i = base_offset; i < num_samples; i += mini_batch_stride) {
for (uint64_t j = i;
j < std::min(num_samples, i - base_offset + mini_batch_stride);
j += sample_stride) {
PyList_Append(inds_list,
python::object(PyLong_FromLong(m_shuffled_indices[j])));
uint64_t base_offset = ds.get_base_offset();

for (uint64_t i = 0; i < samples_to_queue; ++i) {
uint64_t sample_ind = base_offset +
m_dataset_minibatch_offset * mini_batch_stride +
m_dataset_sample_offset * sample_stride;

// We went over the entire epoch
if (sample_ind >= num_samples)
break;

PyList_Append(
inds_list,
python::object(PyLong_FromLong(m_shuffled_indices[sample_ind])));

++m_dataset_sample_offset;
++m_queued_samples;

// Cycle minibatch offset
if (m_dataset_sample_offset * sample_stride + base_offset >=
mini_batch_stride) {
m_dataset_sample_offset = 0;
++m_dataset_minibatch_offset;
}
}

python::object(
PyObject_CallMethod(m_data_reader, "queue_epoch", "(O)", inds_list.get()));
python::object(PyObject_CallMethod(m_data_reader,
"queue_samples",
"(O)",
inds_list.get()));
python::check_error();
}
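
The per-sample arithmetic above replaces the old double loop over whole minibatches: each queued sample's slot in m_shuffled_indices is computed from how many minibatches and in-minibatch samples this reader has already queued, combined with its base offset and strides. The same logic rendered in Python (variable names mirror the C++ members; illustrative only):

def next_shuffled_index(shuffled_indices, base_offset, sample_stride,
                        mini_batch_stride, minibatch_offset, sample_offset):
    ind = (base_offset
           + minibatch_offset * mini_batch_stride
           + sample_offset * sample_stride)
    if ind >= len(shuffled_indices):
        # The whole epoch has already been queued.
        return None, minibatch_offset, sample_offset

    sample_offset += 1
    # Once this reader has queued its share of the current minibatch,
    # move to the next minibatch and restart the in-minibatch counter.
    if sample_offset * sample_stride + base_offset >= mini_batch_stride:
        sample_offset = 0
        minibatch_offset += 1
    return shuffled_indices[ind], minibatch_offset, sample_offset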

void python_dataset_reader::queue_epoch()
{
// Acquire Python GIL
python::global_interpreter_lock gil;

// Resets the sample offset to the beginning of the epoch
m_dataset_minibatch_offset = 0;
m_dataset_sample_offset = 0;
m_queued_samples = 0;

// Prefetch the first set of samples (if less than minibatch size, the first
// minibatch read will take care of the rest)
queue_samples(m_prefetch_factor * m_num_io_threads);
}
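
queue_epoch now only resets the bookkeeping and primes the pipeline with prefetch_factor * num_io_threads samples, instead of handing the whole epoch's indices to Python at once as the old loop did; each fetch_data_block call then keeps the pipeline topped up one minibatch at a time (see the sketch after fetch_data_block above). A quick numeric example of the warm-up size, with assumed values:

prefetch_factor = 2    # assumed value of m_prefetch_factor
num_io_threads = 4     # assumed value handed to setup()
mini_batch_size = 64   # assumed minibatch size

warmup = prefetch_factor * num_io_threads      # 8 samples queued by queue_epoch()
shortfall = max(0, mini_batch_size - warmup)   # 56 more queued by the first fetch_data_block()
print(warmup, shortfall)                       # 8 56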

void python_dataset_reader::load()
{
// Make sure Python is running and acquire GIL