diff --git a/python/mxnet/io.py b/python/mxnet/io.py index c07c7a2b8062..fdb32408d191 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -477,7 +477,7 @@ def _infer_column_shape(self, sarray): else: return (lengths.max(), ) elif dtype is gl.Image: - first_image = sarray.dropna()[0] + first_image = sarray.head(1)[0] return (first_image.channels, first_image.height, first_image.width) def infer_shape(self): diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 1a88d7528be8..23ca1cc8a674 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -584,7 +584,7 @@ void NDArray::SyncCopyFromSFrame(const graphlab::flexible_type *data, size_t siz auto type = data[0].get_type(); if (type == graphlab::flex_type_enum::IMAGE) { CHECK_EQ(size, 1) << "Image data only support one input field"; - graphlab::image_type img = data[0].get(); + const graphlab::image_type& img = data[0].get(); mshadow::Tensor batch_tensor = dst.GetWithShape( mshadow::Shape4(dshape[0], img.m_channels, img.m_height, img.m_width)); @@ -604,16 +604,22 @@ void NDArray::SyncCopyFromSFrame(const graphlab::flexible_type *data, size_t siz } else if (img.m_format == graphlab::Format::PNG) { graphlab::decode_png((const char*)img.get_image_data(), img.m_image_data_size, &buf, length); } - img.m_image_data.reset(buf); - img.m_image_data_size = length; - img.m_format = graphlab::Format::RAW_ARRAY; - } - size_t cnt = 0; - const unsigned char* raw_data = img.get_image_data(); - for (size_t i = 0; i < img.m_height; ++i) { - for (size_t j = 0; j < img.m_width; ++j) { - for (size_t k = 0; k < img.m_channels; ++k) { - batch_tensor[idx][k][i][j] = raw_data[cnt++]; + size_t cnt = 0; + for (size_t i = 0; i < img.m_height; ++i) { + for (size_t j = 0; j < img.m_width; ++j) { + for (size_t k = 0; k < img.m_channels; ++k) { + batch_tensor[idx][k][i][j] = buf[cnt++]; + } + } + } + } else { + size_t cnt = 0; + const unsigned char* raw_data = img.get_image_data(); + for (size_t i = 0; i < img.m_height; ++i) { + for (size_t j = 0; j < img.m_width; ++j) { + for (size_t k = 0; k < img.m_channels; ++k) { + batch_tensor[idx][k][i][j] = raw_data[cnt++]; + } } } }