Skip to content

Commit

Permalink
improve image iter speed
Browse files Browse the repository at this point in the history
  • Loading branch information
haijieg committed Apr 4, 2016
1 parent 8c5bfa9 commit 9ceaef9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _infer_column_shape(self, sarray):
else:
return (lengths.max(), )
elif dtype is gl.Image:
first_image = sarray.dropna()[0]
first_image = sarray.head(1)[0]
return (first_image.channels, first_image.height, first_image.width)

def infer_shape(self):
Expand Down
28 changes: 17 additions & 11 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ void NDArray::SyncCopyFromSFrame(const graphlab::flexible_type *data, size_t siz
auto type = data[0].get_type();
if (type == graphlab::flex_type_enum::IMAGE) {
CHECK_EQ(size, 1) << "Image data only support one input field";
graphlab::image_type img = data[0].get<graphlab::flex_image>();
const graphlab::image_type& img = data[0].get<graphlab::flex_image>();
mshadow::Tensor<cpu, 4> batch_tensor = dst.GetWithShape<cpu, 4, float>(
mshadow::Shape4(dshape[0], img.m_channels, img.m_height, img.m_width));

Expand All @@ -604,16 +604,22 @@ void NDArray::SyncCopyFromSFrame(const graphlab::flexible_type *data, size_t siz
} else if (img.m_format == graphlab::Format::PNG) {
graphlab::decode_png((const char*)img.get_image_data(), img.m_image_data_size, &buf, length);
}
img.m_image_data.reset(buf);
img.m_image_data_size = length;
img.m_format = graphlab::Format::RAW_ARRAY;
}
size_t cnt = 0;
const unsigned char* raw_data = img.get_image_data();
for (size_t i = 0; i < img.m_height; ++i) {
for (size_t j = 0; j < img.m_width; ++j) {
for (size_t k = 0; k < img.m_channels; ++k) {
batch_tensor[idx][k][i][j] = raw_data[cnt++];
size_t cnt = 0;
for (size_t i = 0; i < img.m_height; ++i) {
for (size_t j = 0; j < img.m_width; ++j) {
for (size_t k = 0; k < img.m_channels; ++k) {
batch_tensor[idx][k][i][j] = buf[cnt++];
}
}
}
} else {
size_t cnt = 0;
const unsigned char* raw_data = img.get_image_data();
for (size_t i = 0; i < img.m_height; ++i) {
for (size_t j = 0; j < img.m_width; ++j) {
for (size_t k = 0; k < img.m_channels; ++k) {
batch_tensor[idx][k][i][j] = raw_data[cnt++];
}
}
}
}
Expand Down

0 comments on commit 9ceaef9

Please sign in to comment.