diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h index f6eb45defcc1c..721e9bc69adac 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h @@ -58,21 +58,13 @@ struct ElementLT final { const uint64_t rank; }; -/// The type of callback functions which receive an element. -template -using ElementConsumer = - const std::function &, V)> &; - /// A memory-resident sparse tensor in coordinate-scheme representation -/// (a collection of `Element`s). This data structure is used as -/// an intermediate representation; e.g., for reading sparse tensors -/// from external formats into memory, or for certain conversions between -/// different `SparseTensorStorage` formats. +/// (a collection of `Element`s). This data structure is used as an +/// intermediate representation, e.g., for reading sparse tensors from +/// external formats into memory. template class SparseTensorCOO final { public: - using const_iterator = typename std::vector>::const_iterator; - /// Constructs a new coordinate-scheme sparse tensor with the given /// sizes and an optional initial storage capacity. explicit SparseTensorCOO(const std::vector &dimSizes, @@ -106,7 +98,7 @@ class SparseTensorCOO final { /// Returns the `operator<` closure object for the COO's element type. ElementLT getElementLT() const { return ElementLT(getRank()); } - /// Adds an element to the tensor. This method invalidates all iterators. + /// Adds an element to the tensor. void add(const std::vector &dimCoords, V val) { const uint64_t *base = coordinates.data(); const uint64_t size = coordinates.size(); @@ -135,12 +127,9 @@ class SparseTensorCOO final { elements.push_back(addedElem); } - const_iterator begin() const { return elements.cbegin(); } - const_iterator end() const { return elements.cend(); } - /// Sorts elements lexicographically by coordinates. If a coordinate /// is mapped to multiple values, then the relative order of those - /// values is unspecified. This method invalidates all iterators. + /// values is unspecified. void sort() { if (isSorted) return; diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h index 5e57facaf2376..c5be3d1acc337 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h @@ -37,6 +37,11 @@ namespace mlir { namespace sparse_tensor { +/// The type of callback functions which receive an element. +template +using ElementConsumer = + const std::function &, V)> &; + // Forward references. template class SparseTensorEnumeratorBase;