@@ -795,6 +795,199 @@ void OneHotOpForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+/*!
+ * \brief sparse retain namespace
+ */
+namespace sr {
+enum SparseRetainOpInputs {kArr, kIdx};
+enum SparseRetainOpOutputs {kOut};
+} // namespace sr
+
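+// Shape inference: the output takes the shape of the input array (sr::kArr);
+// the shape of the index array is not constrained here.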
+inline bool SparseRetainOpShape(const nnvm::NodeAttrs& attrs,
+                                std::vector<TShape> *in_attrs,
+                                std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U)
+    << "sparse_retain operator takes 2 arguments (" << in_attrs->size() << " given)";
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  TShape tshape((*in_attrs)[sr::kArr]);
+  shape_assign(&tshape, (*out_attrs)[sr::kOut]);
+  SHAPE_ASSIGN_CHECK(*in_attrs, sr::kArr, tshape);
+  SHAPE_ASSIGN_CHECK(*out_attrs, sr::kOut, tshape);
+  return true;
+}
+
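+// Type inference: the output dtype follows the input array's dtype; the index
+// dtype must already be set by the caller.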
+inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs,
+                               std::vector<int> *in_attrs,
+                               std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  CHECK_NE((*in_attrs)[sr::kIdx], -1) << "Index type must be set for sparse_retain operator";
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[sr::kArr]);
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[sr::kOut]);
+  return (*in_attrs)[0] != -1;
+}
+
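+// Forward storage inference: a row_sparse input array produces a row_sparse output.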
+inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                                std::vector<int> *in_attrs,
+                                                std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (kRowSparseStorage == in_attrs->at(sr::kArr)) {
+    out_attrs->at(sr::kOut) = kRowSparseStorage;
+  }
+  return true;
+}
+
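+// Backward storage inference: the gradient w.r.t. the array is row_sparse, while
+// the slot for the index gradient stays dense (indices get no real gradient).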
+inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                                 std::vector<int> *in_attrs,
+                                                 std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 2U);
+  out_attrs->at(sr::kArr) = kRowSparseStorage;
+  out_attrs->at(sr::kIdx) = kDefaultStorage;
+  return true;
+}
+
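+// Forward kernel: thread i binary-searches idx[i] among the nnr stored row indices
+// of the input; on a hit the matching row is copied to output row i, otherwise the
+// pre-zeroed output row is left as all zeros.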
+struct SparseRetainRspForward {
+  template<typename DType, typename RType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, RType* out_idx,
+                                  const DType* in_data, const RType* in_idx,
+                                  const IType* idx, const size_t nnr,
+                                  const size_t num_cols) {
+    const RType irow = idx[i];
+    int j = -1, left = 0, right = nnr - 1;
+    while (left <= right) {
+      int m = left + (right - left) / 2;
+      const auto in_idx_m = in_idx[m];
+      if (in_idx_m == irow) {
+        j = m;
+        break;
+      } else if (in_idx_m < irow) {
+        left = m + 1;
+      } else {
+        right = m - 1;
+      }
+    }
+    out_idx[i] = idx[i];
+    if (j >= 0) {
+      const size_t in_offset = j * num_cols;
+      const size_t out_offset = i * num_cols;
+      for (size_t k = 0; k < num_cols; ++k) {
+        out_data[out_offset+k] = in_data[in_offset+k];
+      }
+    }
+  }
+};
+
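+// Forward: allocate one output row per requested index, zero the output values,
+// then launch the kernel with one thread per index.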
+template<typename xpu>
+void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'";
+
+  CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage)
+    << "sparse_retain operator only takes row sparse NDArray as input";
+  CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
+    << "sparse_retain operator only takes default NDArray as its index array";
+  CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage)
+    << "sparse_retain operator only outputs row sparse NDArray";
+
+  const NDArray& input_nd = inputs[sr::kArr];
+  const TBlob idx_data = inputs[sr::kIdx].data();
+
+  if (req[sr::kOut] == kNullOp
+      || !input_nd.storage_initialized()
+      || idx_data.Size() == 0U) return;
+
+  const TBlob input_data = input_nd.data();
+  if (input_data.shape_[0] == 0) return;
+  const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx);
+
+  NDArray output_nd = outputs[sr::kOut];
+  output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
+  TBlob output_data = output_nd.data();
+  TBlob output_idx = output_nd.aux_data(rowsparse::kIdx);
+
+  using namespace mxnet_op;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, {  // output data type
+    MSHADOW_INT_TYPE_SWITCH(output_idx.type_flag_, RType, {  // row index data type
+      MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, {  // index array data type
+        Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>());
+        Kernel<SparseRetainRspForward, xpu>::Launch(s, idx_data.Size(), output_data.dptr<DType>(),
+            output_idx.dptr<RType>(), input_data.dptr<DType>(), input_idx.dptr<RType>(),
+            idx_data.dptr<IType>(), input_data.shape_[0], input_data.shape_[1]);
+      });
+    });
+  });
+}
+
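+// Backward kernel: gradient row i takes its row index from idx[i] and copies that
+// row of the dense out_grad; rows that were not retained receive no gradient.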
+template<int req>
+struct SparseRetainRspBackward {
+  template<typename DType, typename RType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* in_grad, RType* in_grad_idx,
+                                  const DType* out_grad, const IType* idx,
+                                  const size_t num_cols) {
+    const RType irow = idx[i];
+    in_grad_idx[i] = irow;
+    const size_t out_offset = irow * num_cols;
+    const size_t in_offset = i * num_cols;
+    for (size_t j = 0; j < num_cols; ++j) {
+      KERNEL_ASSIGN(in_grad[in_offset+j], req, out_grad[out_offset+j]);
+    }
+  }
+};
+
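+// Backward: the input gradient is a row_sparse array with one row per retained
+// index, gathered from the dense output gradient.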
+template<typename xpu>
+void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 2U);
+  CHECK_EQ(req.size(), 2U);
+  CHECK_NE(req[sr::kArr], kWriteInplace);
+  CHECK_EQ(req[sr::kIdx], kNullOp)
+    << "sparse_retain does not support calculating gradients of indices";
+
+  CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage)
+    << "sparse_retain backward only takes default NDArray as ograd";
+  CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
+    << "sparse_retain backward only takes default NDArray as its index array";
+  CHECK_EQ(outputs[sr::kArr].storage_type(), kRowSparseStorage)
+    << "sparse_retain backward only outputs row sparse NDArray as grad of input";
+
+  const TBlob out_grad_data = inputs[sr::kOut].data();
+  const TBlob idx_data = inputs[sr::kIdx].data();
+
+  NDArray in_grad_nd = outputs[sr::kArr];
+  in_grad_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
+  TBlob in_grad_data = in_grad_nd.data();
+  TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx);
+
+  using namespace mxnet_op;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, {  // output data type
+    MSHADOW_INT_TYPE_SWITCH(in_grad_idx.type_flag_, RType, {  // row index data type
+      MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, {  // index array data type
+        MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, {
+          Kernel<SparseRetainRspBackward<req_type>, xpu>::Launch(
+              s, in_grad_idx.Size(), in_grad_data.dptr<DType>(), in_grad_idx.dptr<RType>(),
+              out_grad_data.dptr<DType>(), idx_data.dptr<IType>(), out_grad_data.shape_[1]);
+        });
+      });
+    });
+  });
+}
+
 } // namespace op
 } // namespace mxnet
 #ifdef __CUDACC__