diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 262f4692d2b2..da1ab4753761 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3239,16 +3239,29 @@ def get_tensor_expr(self, tensor, is_sparse=False): # pylint: disable=no-else-return def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type): """ Prepare sparse indices and dense matrix from TFLite sparse parameters. """ + # The function is implemented based on TFLite sparse parameter specifications + # Please refer + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs#L89 + # for details about each parameters sparsity = sparse_tensor.Sparsity() dense_shape = sparse_tensor.ShapeAsNumpy() orig_rank = len(dense_shape) + + # The traversal order of the dimensions defined in the `shape` field of the to be dense tensor. traversal_order = sparsity.TraversalOrderAsNumpy() + + # For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + # stores how a block dimension in (dn, ..., dn+k-1) maps to the original + # tensor dimension in (d0, ..., dn). It's stored in the order of (dn, ..., dn+k-1). + # If not block-sparse, this field is NULL. block_map = sparsity.BlockMapAsNumpy() + total_rank = sparsity.TraversalOrderLength() dense_mat = np.full(shape=dense_shape, fill_value=0, dtype=sparse_tensor_type).flatten() from enum import Enum + # NOTE: Here the Vector term is borrowed from TFLite spec. class VectorType(Enum): Empty = 0 Int32 = 1 @@ -3274,8 +3287,23 @@ def _get_flattened_index(indices, shape): sub_elements *= shape[i] return index + # DimensionMetadata per dimension: the metadata needed for + # each dimension to locate the non-zero values in the original dense tensor + # inline with traversal order parameter. + # + # sp_format has 2 possible values: {DENSE = 0, SPARSE_CSR = 1} + # If format = DENSE{0} : DenseSize represents size of that dimension + # If format = SPARSE_CSR{1} : array_segments represents how to segment the indices array, + # each segment corresponds to one element in the previous dimension. array_indices + # represents the index of the non-zero elements within this dimension + # (as those in the CSR matrix format, where the first array is row pointers + # and the second array is column indices). sp_format = np.zeros(sparsity.DimMetadataLength()) dim_metadata = [None] * (2 * sparsity.DimMetadataLength()) + + # Below loop will fetch all meta data per dimension based on format type + # Dense or Sparse and will put it in an agnostic array for easy access + # while preparing dense buffer or indices. for i in range(sparsity.DimMetadataLength()): sp_format[i] = sparsity.DimMetadata(i).Format() if sp_format[i] == 0: @@ -3301,6 +3329,7 @@ def _get_flattened_index(indices, shape): block_dim = 0 block_size = np.zeros(sparsity.BlockMapLength()) + # Block size parameter if encoded in BSR format for i in range(orig_rank): if block_dim < sparsity.BlockMapLength() and block_map[block_dim] == i: orig_dim = traversal_order[orig_rank + block_dim] @@ -3309,6 +3338,8 @@ def _get_flattened_index(indices, shape): indices_list = [] + # Below function iterates through each applicable indices per dimension + # based on format type specified and finaly produce the dense matrix and the NZ indices. def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx): if level == len(indices): start_pos = 0