Skip to content

Commit

Permalink
Use CPUPinned context in ImageRecordIOParser2 (apache#13980)
Browse files Browse the repository at this point in the history
* create NDArray with CPUPinned context in ImageRecordIOParser2

* update document

* use -1 device_id as an option to create CPU(0) context

* retrigger CI

* fix cpplint error
  • Loading branch information
yuxihu authored and haohuw committed Jun 23, 2019
1 parent b191aee commit 8ca025c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
13 changes: 10 additions & 3 deletions src/io/image_iter_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,13 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
bool verbose;
/*! \brief partition the data into multiple parts */
int num_parts;
/*! \brief the index of the part will read*/
/*! \brief the index of the part will read */
int part_index;
/*! \brief the size of a shuffle chunk*/
/*! \brief device id used to create context for internal NDArray */
int device_id;
/*! \brief the size of a shuffle chunk */
size_t shuffle_chunk_size;
/*! \brief the seed for chunk shuffling*/
/*! \brief the seed for chunk shuffling */
int shuffle_chunk_seed;
/*! \brief random seed for augmentations */
dmlc::optional<int> seed_aug;
Expand Down Expand Up @@ -163,6 +165,11 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
.describe("Virtually partition the data into these many parts.");
DMLC_DECLARE_FIELD(part_index).set_default(0)
.describe("The *i*-th virtual partition to be read.");
DMLC_DECLARE_FIELD(device_id).set_default(0)
.describe("The device id used to create context for internal NDArray. "\
"Setting device_id to -1 will create Context::CPU(0). Setting "
"device_id to valid positive device id will create "
"Context::CPUPinned(device_id). Default is 0.");
DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0)
.describe("The data shuffle buffer size in MB. Only valid if shuffle is true.");
DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0)
Expand Down
9 changes: 7 additions & 2 deletions src/io/iter_image_recordio_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,14 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
shape_vec.push_back(param_.label_width);
TShape label_shape(shape_vec.begin(), shape_vec.end());

out->data.at(0) = NDArray(data_shape, Context::CPU(0), false,
auto ctx = Context::CPU(0);
auto dev_id = param_.device_id;
if (dev_id != -1) {
ctx = Context::CPUPinned(dev_id);
}
out->data.at(0) = NDArray(data_shape, ctx, false,
mshadow::DataType<DType>::kFlag);
out->data.at(1) = NDArray(label_shape, Context::CPU(0), false,
out->data.at(1) = NDArray(label_shape, ctx, false,
mshadow::DataType<real_t>::kFlag);
unit_size_[0] = param_.data_shape.Size();
unit_size_[1] = param_.label_width;
Expand Down

0 comments on commit 8ca025c

Please sign in to comment.