// ---------------------------------------------------------------------------------------
//                                   ILGPU Algorithms
//                           Copyright (c) 2023 ILGPU Project
//                                    www.ilgpu.net
//
// File: MaskedMatrixProcessor.cs
//
// This file is part of ILGPU and is distributed under the University of Illinois Open
// Source License. See LICENSE.txt for details.
// ---------------------------------------------------------------------------------------

using ILGPU.Runtime;
using ILGPU.Util;
using System;
using System.Collections.Generic;

namespace ILGPU.Algorithms.MatrixOperations
{
    // NOTE(review): the generic argument lists in this file were restored from the
    // constraint clauses and from the type arguments passed to
    // CreateSparseTransposedMatrixMultiplierMasked. Confirm them against the
    // declarations of MaskedSparseMatrixMultiplier, SparseMatrixView,
    // InlineList.IPredicate and IMaskedSparseMatrixProcessor.

    /// <summary>
    /// A processor for masked matrices to efficiently operate on multiple matrix
    /// instances in parallel to maximize occupancy.
    /// </summary>
    /// <typeparam name="T">The element type.</typeparam>
    /// <typeparam name="TPredicate">The predicate type used for masking.</typeparam>
    /// <typeparam name="TStride">The 2D stride type of the matrix views.</typeparam>
    /// <typeparam name="TProcessor">
    /// The masked sparse matrix processor type.
    /// </typeparam>
    public class MaskedMatrixProcessor<T, TPredicate, TStride, TProcessor>
        : ConcurrentStreamProcessor
        where T : unmanaged
        where TStride : struct, IStride2D
        where TPredicate : struct, InlineList.IPredicate<T>
        where TProcessor : struct, IMaskedSparseMatrixProcessor<T>
    {
        #region Instance

        /// <summary>
        /// The internal masked matrix multiplier which contains pre-compiled kernels.
        /// </summary>
        private readonly MaskedSparseMatrixMultiplier<T, TPredicate, TStride>
            matrixMultiplier;

        /// <summary>
        /// Constructs a new masked processor.
        /// </summary>
        /// <param name="accelerator">The parent accelerator.</param>
        /// <param name="maxNumConcurrentStreams">
        /// The maximum number of concurrent streams to use (if any).
        /// </param>
        /// <param name="streamProvider">
        /// A custom stream provider function to construct specialized streams.
        /// </param>
        public MaskedMatrixProcessor(
            Accelerator accelerator,
            int maxNumConcurrentStreams = 0,
            Func<Accelerator, AcceleratorStream> streamProvider = null)
            : base(accelerator, maxNumConcurrentStreams, streamProvider)
        {
            matrixMultiplier = accelerator.CreateSparseTransposedMatrixMultiplierMasked<
                T,
                TPredicate,
                TStride,
                TProcessor>();
        }

        #endregion

        /// <summary>
        /// Gets or sets the current predicate to use (if any).
        /// </summary>
        public TPredicate? Predicate { get; set; }

        #region Methods

        /// <summary>
        /// Multiplies the given matrices using the currently assigned predicate.
        /// </summary>
        /// <param name="stream">The current accelerator stream to use.</param>
        /// <param name="aView">The dense input matrix a of shape MxK.</param>
        /// <param name="bView">
        /// The sparse matrix b of shape NxK (will transpose).
        /// </param>
        /// <param name="outView">
        /// A dense output matrix of shape of <paramref name="aView"/>.
        /// </param>
        /// <exception cref="InvalidOperationException">
        /// Thrown if no <see cref="Predicate"/> has been assigned.
        /// </exception>
        public void MultiplyTransposed(
            AcceleratorStream stream,
            ArrayView2D<T, TStride> aView,
            SparseMatrixView<T, TStride> bView,
            ArrayView2D<T, TStride> outView)
        {
            // Read the property exactly once so that the presence check and the value
            // passed to the kernel always refer to the same predicate instance.
            TPredicate predicate = Predicate ?? throw new InvalidOperationException(
                "A predicate must be assigned before multiplying matrices.");
            matrixMultiplier(stream, predicate, aView, bView, outView);
        }

        /// <summary>
        /// Multiplies the given matrices using the currently assigned predicate.
        /// </summary>
        /// <param name="stream">The current accelerator stream to use.</param>
        /// <param name="aViews">The dense input matrices a of shape MxK.</param>
        /// <param name="bViews">
        /// The sparse matrices b of shape NxK (will transpose).
        /// </param>
        /// <param name="outViews">
        /// Dense output matrices of shape of <paramref name="aViews"/>.
        /// </param>
        /// <exception cref="ArgumentNullException">
        /// Thrown if one of the view lists is null.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Thrown if <paramref name="bViews"/> or <paramref name="outViews"/> does not
        /// contain the same number of elements as <paramref name="aViews"/>.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        /// Thrown by <see cref="MultiplyTransposed"/> if no <see cref="Predicate"/> has
        /// been assigned.
        /// </exception>
        public void MultiplyBatchedTransposed(
            AcceleratorStream stream,
            IReadOnlyList<ArrayView2D<T, TStride>> aViews,
            IReadOnlyList<SparseMatrixView<T, TStride>> bViews,
            IReadOnlyList<ArrayView2D<T, TStride>> outViews)
        {
            // Report missing lists with the offending parameter name instead of
            // failing with a NullReferenceException on the Count accesses below.
            if (aViews is null)
                throw new ArgumentNullException(nameof(aViews));
            if (bViews is null)
                throw new ArgumentNullException(nameof(bViews));
            if (outViews is null)
                throw new ArgumentNullException(nameof(outViews));

            // All three lists are indexed in lock-step and must have equal length.
            if (aViews.Count != bViews.Count)
                throw new ArgumentOutOfRangeException(nameof(bViews));
            if (aViews.Count != outViews.Count)
                throw new ArgumentOutOfRangeException(nameof(outViews));

            // The callback receives the stream to use and the index of the matrix
            // triple to process.
            ProcessConcurrently(stream, aViews.Count, (acceleratorStream, i) =>
                MultiplyTransposed(acceleratorStream, aViews[i], bViews[i], outViews[i]));
        }

        #endregion
    }
}