Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Large Tensor] Fixed col2im op (#17622)
Browse files Browse the repository at this point in the history
* Changed dtype for index_im

* Added nightly test for col2im
  • Loading branch information
connorgoggins authored Feb 24, 2020
1 parent 3f0b049 commit f9b2a63
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/operator/nn/im2col.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ inline void im2col_nd_core_cpu(const DType* data_input, const bool im2col,
// Loop over spatial axes in forward order to compute the indices in the
// image and column, and whether the index lies in the padding.
index_t index_col = c_col;
int index_im = c_col / kernel_size;
index_t index_im = c_col / kernel_size;
bool is_padding = false;
for (index_t d_i = 0; d_i < num_spatial_axes; ++d_i) {
const index_t d = d_iter[d_i];
Expand All @@ -191,7 +191,7 @@ inline void im2col_nd_core_cpu(const DType* data_input, const bool im2col,
is_padding |= d_im < 0 || d_im >= static_cast<int>(im_shape[d_i + 2]);
index_col *= col_shape[d_i + 1];
index_col += d;
index_im *= static_cast<int>(im_shape[d_i + 2]);
index_im *= static_cast<index_t>(im_shape[d_i + 2]);
index_im += d_im;
}
if (im2col) {
Expand Down
14 changes: 14 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,19 @@ def npy_instance_norm(data, gamma, beta, axis, eps=1E-5):
assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
forward_check_eps)

def check_col2im():
data = nd.random_normal(shape=(1, 2**30, 4))
output_size = (2, 2, 1)
kernel = (1, 1, 1)

res = nd.col2im(data=data, output_size=output_size, kernel=kernel)

assert res.shape[0] == 1
assert res.shape[1] == 1073741824
assert res.shape[2] == 2
assert res.shape[3] == 2
assert res.shape[4] == 1

check_gluon_embedding()
check_fully_connected()
check_dense()
Expand All @@ -474,6 +487,7 @@ def npy_instance_norm(data, gamma, beta, axis, eps=1E-5):
check_linear_and_logistic_regression()
check_l2_normalization()
check_instance_norm()
check_col2im()


def test_tensor():
Expand Down

0 comments on commit f9b2a63

Please sign in to comment.