diff --git a/guide/basic.cpp b/guide/basic.cpp index ce03be057dd7..8e90ac966053 100644 --- a/guide/basic.cpp +++ b/guide/basic.cpp @@ -140,6 +140,22 @@ int main(void) { printf("\n"); } + printf("mask\n"); + TensorContainer mask_data(Shape2(6, 8)); + TensorContainer mask_out(Shape2(6, 8)); + TensorContainer mask_src(Shape1(6)); + + mask_data = 1.0f; + for (int i = 0; i < 6; ++i) { + mask_src[i] = static_cast(i); + } + mask_out = mask(mask_src, mask_data); + for (index_t i = 0; i < mask_out.size(0); ++i) { + for (index_t j = 0; j < mask_out.size(1); ++j) { + printf("%.2f ", mask_out[i][j]); + } + printf("\n"); + } ShutdownTensorEngine(); return 0; } diff --git a/mshadow/extension.h b/mshadow/extension.h index 4d25fa7d7d59..7af0f56f7699 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -37,4 +37,5 @@ #include "./extension/flip.h" #include "./extension/complex.h" #include "./extension/range.h" +#include "./extension/mask.h" #endif // MSHADOW_EXTENSION_H_ diff --git a/mshadow/extension/mask.h b/mshadow/extension/mask.h new file mode 100644 index 000000000000..0fd4cc6db72e --- /dev/null +++ b/mshadow/extension/mask.h @@ -0,0 +1,97 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file mask.h + * \brief + * \author Bing Xu +*/ +#ifndef MSHADOW_EXTENSION_MASK_H_ +#define MSHADOW_EXTENSION_MASK_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief Broadcast a mask and do element-wise multiplication + * \tparam IndexExp type of index expression + * \tparam SrcExp type of src expression + * \tparam DType data type + */ +template +struct MaskExp: public Exp, + DType, type::kChainer> { + /*! \brief index oprand */ + const IndexExp &index_; + /*! \brief matrix oprand */ + const SrcExp &src_; + /*! constructor */ + MaskExp(const IndexExp &index, const SrcExp &src) + : index_(index), src_(src) {} +}; // struct MaskExp + + + +template +inline MaskExp +mask(const Exp &index, + const Exp &src) { + return MaskExp(index.self(), src.self()); +} + + +//---------------------- +// Execution plan +//---------------------- + +template +struct Plan, DType> { + public: + explicit Plan(const MaskExp &e) + : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { + } + + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return static_cast(src_.Eval(y, x) * index_.Eval(0, y)); + } + + private: + expr::Plan index_; + expr::Plan src_; +}; // struct Plan + +template +inline Plan, DType> +MakePlan(const MaskExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const MaskExp &t) { + CHECK(dim == 2) + << "MaskExp only support 2D output"; + Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); + Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); + CHECK_EQ(dshape[0], wshape[0]) << "MaskExp require inputs in same first dimention"; + Shape ret; + ret[0] = wshape[0]; + ret[1] = wshape[1]; + return ret; + } +}; + + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow + +#endif // MSHADOW_EXTENSION_MASK_H_