diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h index c9e3933ade28..10cd8ab4e5de 100644 --- a/src/io/image_iter_common.h +++ b/src/io/image_iter_common.h @@ -125,11 +125,13 @@ struct ImageRecParserParam : public dmlc::Parameter { bool verbose; /*! \brief partition the data into multiple parts */ int num_parts; - /*! \brief the index of the part will read*/ + /*! \brief the index of the part will read */ int part_index; - /*! \brief the size of a shuffle chunk*/ + /*! \brief device id used to create context for internal NDArray */ + int device_id; + /*! \brief the size of a shuffle chunk */ size_t shuffle_chunk_size; - /*! \brief the seed for chunk shuffling*/ + /*! \brief the seed for chunk shuffling */ int shuffle_chunk_seed; /*! \brief random seed for augmentations */ dmlc::optional seed_aug; @@ -163,6 +165,11 @@ struct ImageRecParserParam : public dmlc::Parameter { .describe("Virtually partition the data into these many parts."); DMLC_DECLARE_FIELD(part_index).set_default(0) .describe("The *i*-th virtual partition to be read."); + DMLC_DECLARE_FIELD(device_id).set_default(0) + .describe("The device id used to create context for internal NDArray. "\ + "Setting device_id to -1 will create Context::CPU(0). Setting " + "device_id to valid positive device id will create " + "Context::CPUPinned(device_id). Default is 0."); DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0) .describe("The data shuffle buffer size in MB. Only valid if shuffle is true."); DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0) diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc index 89f7753983db..00c38198659f 100644 --- a/src/io/iter_image_recordio_2.cc +++ b/src/io/iter_image_recordio_2.cc @@ -285,9 +285,14 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { shape_vec.push_back(param_.label_width); TShape label_shape(shape_vec.begin(), shape_vec.end()); - out->data.at(0) = NDArray(data_shape, Context::CPU(0), false, + auto ctx = Context::CPU(0); + auto dev_id = param_.device_id; + if (dev_id != -1) { + ctx = Context::CPUPinned(dev_id); + } + out->data.at(0) = NDArray(data_shape, ctx, false, mshadow::DataType::kFlag); - out->data.at(1) = NDArray(label_shape, Context::CPU(0), false, + out->data.at(1) = NDArray(label_shape, ctx, false, mshadow::DataType::kFlag); unit_size_[0] = param_.data_shape.Size(); unit_size_[1] = param_.label_width;