From b54dce216bf5449f59b4de12a6ff0b7d7d17bd95 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 25 May 2024 14:21:24 +0530 Subject: [PATCH] Split generic_matmul for strided matrices into two halves (#54552) --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 4 ++ stdlib/LinearAlgebra/src/matmul.jl | 78 +++++++++++++++-------- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 5663f66eac8b3..e2ad7873ed834 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -576,6 +576,10 @@ wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U') wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U') wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U') +wrapper_char_NTC(A::AbstractArray) = uppercase(wrapper_char(A)) == 'N' +wrapper_char_NTC(A::Union{StridedArray, Adjoint, Transpose}) = true +wrapper_char_NTC(A::Union{Symmetric, Hermitian}) = false + Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar) # merge the result of this before return, so that we can type-assert the return such # that even if the tmerge is inaccurate, inference can still identify that the diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 3e7c0a0ecc862..d745122c3700a 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -293,15 +293,24 @@ true @inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β) # Add a level of indirection and specialize _mul! to avoid ambiguities in mul! @inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = - generic_matmatmul!( + generic_matmatmul_wrapper!( C, wrapper_char(A), wrapper_char(B), _unwrap(A), _unwrap(B), - α, β + α, β, + Val(wrapper_char_NTC(A) & wrapper_char_NTC(B)) ) +# this indirection allows is to specialize on the types of the wrappers of A and B to some extent, +# even though the wrappers are stripped off in mul! +# By default, we ignore the wrapper info and forward the arguments to generic_matmatmul! +Base.@constprop :aggressive function generic_matmatmul_wrapper!(C, tA, tB, A, B, α, β, @nospecialize(val)) + generic_matmatmul!(C, tA, tB, A, B, α, β) +end + + """ rmul!(A, B) @@ -368,9 +377,9 @@ julia> lmul!(F.Q, B) """ lmul!(A, B) -# THE one big BLAS dispatch -Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number) where {T<:BlasFloat} +# THE one big BLAS dispatch. This is split into two methods to improve latency +Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + α::Number, β::Number, ::Val{true}) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) @@ -389,19 +398,37 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, # and extract the char corresponding to the wrapper type tA_uc, tB_uc = uppercase(tA), uppercase(tB) # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. - if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) - if tA_uc == 'T' && tB_uc == 'N' && A === B - return syrk_wrapper!(C, 'T', A, α, β) - elseif tA_uc == 'N' && tB_uc == 'T' && A === B - return syrk_wrapper!(C, 'N', A, α, β) - elseif tA_uc == 'C' && tB_uc == 'N' && A === B - return herk_wrapper!(C, 'C', A, α, β) - elseif tA_uc == 'N' && tB_uc == 'C' && A === B - return herk_wrapper!(C, 'N', A, α, β) - else - return gemm_wrapper!(C, tA, tB, A, B, α, β) + if tA_uc == 'T' && tB_uc == 'N' && A === B + return syrk_wrapper!(C, 'T', A, α, β) + elseif tA_uc == 'N' && tB_uc == 'T' && A === B + return syrk_wrapper!(C, 'N', A, α, β) + elseif tA_uc == 'C' && tB_uc == 'N' && A === B + return herk_wrapper!(C, 'C', A, α, β) + elseif tA_uc == 'N' && tB_uc == 'C' && A === B + return herk_wrapper!(C, 'N', A, α, β) + else + return gemm_wrapper!(C, tA, tB, A, B, α, β) + end +end +Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + α::Number, β::Number, ::Val{false}) where {T<:BlasFloat} + mA, nA = lapack_size(tA, A) + mB, nB = lapack_size(tB, B) + if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) + if size(C) != (mA, nB) + throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) end + return _rmul_or_fill!(C, β) + end + if size(C) == size(A) == size(B) == (2,2) + return matmul2x2!(C, tA, tB, A, B, α, β) + end + if size(C) == size(A) == size(B) == (3,3) + return matmul3x3!(C, tA, tB, A, B, α, β) end + # We convert the chars to uppercase to potentially unwrap a WrapperChar, + # and extract the char corresponding to the wrapper type + tA_uc, tB_uc = uppercase(tA), uppercase(tB) alpha, beta = promote(α, β, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} if tA_uc == 'S' && tB_uc == 'N' @@ -421,18 +448,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) -# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. -Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - α::Number, β::Number) where {T<:BlasReal} - # We convert the chars to uppercase to potentially unwrap a WrapperChar, - # and extract the char corresponding to the wrapper type - tA_uc, tB_uc = uppercase(tA), uppercase(tB) - # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. - if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) - gemm_wrapper!(C, tA, tB, A, B, α, β) - else - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) - end +function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + α::Number, β::Number, ::Val{true}) where {T<:BlasReal} + gemm_wrapper!(C, tA, tB, A, B, α, β) +end +Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + α::Number, β::Number, ::Val{false}) where {T<:BlasReal} + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) end # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},