diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 4ddb6bb55491..403d606acd5a 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -7,6 +7,7 @@
 #define MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_
 
 #include <mxnet/operator_util.h>
+#include <dmlc/timer.h>
 #include <vector>
 #include <algorithm>
 #include <utility>
@@ -683,9 +684,8 @@ struct DotCsrTransDnsRspByRowBlocks {
         if (col_idx < seg_start || col_idx >= seg_end) continue;
         const size_t offset_out = col_idx * num_cols;
         row_idx[col_idx] = 1;
-        const auto val = data_l[k];
         for (size_t l = 0; l < num_cols; ++l) {
-          out[offset_out+l] += data_r[offset_r+l] * val;
+          out[offset_out+l] += data_r[offset_r+l] * data_l[k];
         }
       }
     }
@@ -903,6 +903,7 @@ void DotCsrDnsRspImpl(const OpContext& ctx,
       nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
       ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
       ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1]));
+      LOG(INFO) << "DotCsrDnsRspImpl: output storage shape = " << ret->storage_shape();
       if (0 == nnr) return;
       mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
       size_t idx = 0;
@@ -942,6 +943,7 @@ void DotCsrRspDnsImpl(const OpContext& ctx,
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
+  LOG(INFO) << "DotCsrRspDnsImpl: rhs storage shape = " << rhs.storage_shape();
   mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
   if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
     if (kWriteTo == req) {
@@ -1007,6 +1009,7 @@ void DotCsrRspRspImpl(const OpContext& ctx,
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
+  LOG(INFO) << "DotCsrRspRspImpl: rhs storage shape = " << rhs.storage_shape();
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
 
   if (!lhs.storage_initialized() || !rhs.storage_initialized()) return;
@@ -1048,6 +1051,7 @@ void DotCsrRspRspImpl(const OpContext& ctx,
       nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
       ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
       ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1]));
+      LOG(INFO) << "DotCsrRspRspImpl: output storage shape = " << ret->storage_shape();
       if (0 == nnr) return;
       mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
       size_t idx = 0;
@@ -1129,19 +1133,31 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
   auto out_stype = outputs[0].storage_type();
   if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) {
     TBlob ret = outputs[0].data();
+    double start = dmlc::GetTime();
     DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret);
+    double elapse = dmlc::GetTime() - start;
+    LOG(INFO) << "DotCsrDnsDnsImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms";
   } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage
              && out_stype == kDefaultStorage) {
     TBlob ret = outputs[0].data();
+    double start = dmlc::GetTime();
     DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+    double elapse = dmlc::GetTime() - start;
+    LOG(INFO) << "DotCsrRspDnsImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms";
   } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage
              && out_stype == kRowSparseStorage) {
     NDArray out = outputs[0];
+    double start = dmlc::GetTime();
     DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out);
+    double elapse = dmlc::GetTime() - start;
+    LOG(INFO) << "DotCsrDnsRspImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms";
   } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage
              && out_stype == kRowSparseStorage) {
     NDArray ret = outputs[0];
+    double start = dmlc::GetTime();
     DotCsrRspRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+    double elapse = dmlc::GetTime() - start;
+    LOG(INFO) << "DotCsrRspRspImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms";
   } else {
     FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_");
   }