diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 97acced29d6e..df0357369fed 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -330,7 +330,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, : param.axis); CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis; - const int channelCount = dshape[channelAxis]; + const index_t channelCount = dshape[channelAxis]; if (!mxnet::ndim_is_known(dshape)) { return false; diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index d385b93e9cff..e3d641af4015 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -47,7 +47,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, CHECK(axis >= 0 && axis < dshape.ndim()) << "Channel axis out of range: axis=" << param.axis; - const int channelCount = dshape[axis]; + const index_t channelCount = dshape[axis]; if (!mxnet::ndim_is_known(dshape)) { return false;