From 05ab1123e14d2486ece500c3f341fd439f353811 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 17 May 2023 20:15:18 +0200 Subject: [PATCH 1/7] Include `HermOrSym` in character-based `mul!` dispatch --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 28 +++++++ stdlib/LinearAlgebra/src/adjtrans.jl | 7 +- stdlib/LinearAlgebra/src/matmul.jl | 96 ++++++++++++++++------- stdlib/LinearAlgebra/src/symmetric.jl | 87 +------------------- 4 files changed, 99 insertions(+), 119 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 5cda4af366814..50d82c497282d 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -457,6 +457,34 @@ const ⋅ = dot const × = cross export ⋅, × +wrapper_char(::AbstractArray) = 'N' +wrapper_char(::Adjoint) = 'C' +wrapper_char(::Adjoint{<:Real}) = 'T' +wrapper_char(::Transpose) = 'T' +wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h' +wrapper_char(A::Hermitian{<:Real}) = A.uplo == 'U' ? 'S' : 's' +wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's' + +function wrap(A::AbstractVecOrMat, tA::AbstractChar) + if tA == 'N' + return A + elseif tA == 'T' + return transpose(A) + elseif tA == 'C' + return adjoint(A) + elseif tA == 'H' + return Hermitian(A, :U) + elseif tA == 'h' + return Hermitian(A, :L) + elseif tA == 'S' + return Symmetric(A, :U) + else # tA == 's' + return Symmetric(A, :L) + end +end + +_unwrap(A::AbstractVecOrMat) = A + ## convenience methods ## return only the solution of a least squares problem while avoiding promoting ## vectors to matrices. diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index 2f5c5508e0ee3..79abdd11a206a 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -94,11 +94,8 @@ inplace_adj_or_trans(::Type{<:AbstractArray}) = copyto! inplace_adj_or_trans(::Type{<:Adjoint}) = adjoint! inplace_adj_or_trans(::Type{<:Transpose}) = transpose! -adj_or_trans_char(::T) where {T<:AbstractArray} = adj_or_trans_char(T) -adj_or_trans_char(::Type{<:AbstractArray}) = 'N' -adj_or_trans_char(::Type{<:Adjoint}) = 'C' -adj_or_trans_char(::Type{<:Adjoint{<:Real}}) = 'T' -adj_or_trans_char(::Type{<:Transpose}) = 'T' +_unwrap(A::Adjoint) = parent(A) +_unwrap(A::Transpose) = parent(A) Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent) Base.unaliascopy(A::Union{Adjoint,Transpose}) = typeof(A)(Base.unaliascopy(A.parent)) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 170aacee6682f..5479b85e4d58e 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -68,24 +68,24 @@ end @inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, alpha::Number, beta::Number) = - generic_matvecmul!(y, adj_or_trans_char(A), _parent(A), x, MulAddMul(alpha, beta)) + generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, MulAddMul(alpha, beta)) # BLAS cases # equal eltypes @inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} = - gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta) + gemv!(y, tA, A, x, _add.alpha, _add.beta) # Real (possibly transposed) matrix times complex vector. # Multiply the matrix with the real and imaginary parts separately @inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} = - gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta) + gemv!(y, tA, A, x, _add.alpha, _add.beta) # Complex matrix times real vector. # Reinterpret the matrix as a real matrix and do real matvec computation. # works only in cooperation with BLAS when A is untransposed (tA == 'N') # but that check is included in gemv! anyway @inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} = - gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta) + gemv!(y, tA, A, x, _add.alpha, _add.beta) # Vector-Matrix multiplication (*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')' @@ -267,10 +267,10 @@ julia> C @inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = generic_matmatmul!( C, - adj_or_trans_char(A), - adj_or_trans_char(B), - _parent(A), - _parent(B), + wrapper_char(A), + wrapper_char(B), + _unwrap(A), + _unwrap(B), MulAddMul(α, β) ) @@ -340,19 +340,32 @@ julia> lmul!(F.Q, B) """ lmul!(A, B) +# THE one big BLAS dispatch @inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} - if tA == 'T' && tB == 'N' && A === B - return syrk_wrapper!(C, 'T', A, _add) - elseif tA == 'N' && tB == 'T' && A === B - return syrk_wrapper!(C, 'N', A, _add) - elseif tA == 'C' && tB == 'N' && A === B - return herk_wrapper!(C, 'C', A, _add) - elseif tA == 'N' && tB == 'C' && A === B - return herk_wrapper!(C, 'N', A, _add) - else - return gemm_wrapper!(C, tA, tB, A, B, _add) + _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} + alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + if alpha isa Union{Bool,T} && beta isa Union{Bool,T} + if tA == 'T' && tB == 'N' && A === B + return syrk_wrapper!(C, 'T', A, _add) + elseif tA == 'N' && tB == 'T' && A === B + return syrk_wrapper!(C, 'N', A, _add) + elseif tA == 'C' && tB == 'N' && A === B + return herk_wrapper!(C, 'C', A, _add) + elseif tA == 'N' && tB == 'C' && A === B + return herk_wrapper!(C, 'N', A, _add) + elseif tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') + return gemm_wrapper!(C, tA, tB, A, B, _add) + elseif (tA == 'S' || tA == 's') && tB == 'N' + return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) + elseif (tB == 'S' || tB == 's') && tA == 'N' + return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) + elseif (tA == 'H' || tA == 'h') && tB == 'N' + return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C) + elseif (tB == 'H' || tB == 'h') && tA == 'N' + return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) + end end + return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. @@ -394,8 +407,19 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x:: alpha, beta = promote(α, β, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - !iszero(stride(x, 1)) # We only check input's stride here. - return BLAS.gemv!(tA, alpha, A, x, beta, y) + !iszero(stride(x, 1)) && # We only check input's stride here. + if tA in ('N', 'T', 'C') + return BLAS.gemv!(tA, alpha, A, x, beta, y) + elseif tA in ('S', 's') + return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y) + elseif tA in ('H', 'h') + return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y) + end + end + if tA in ('S', 's', 'H', 'h') + # re-wrap again and use plain ('N') matvec mul algorithm, + # because _generic_matvecmul! can't handle the HermOrSym cases specifically + return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) else return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end @@ -645,9 +669,14 @@ end # NOTE: the generic version is also called as fallback for # strides != 1 cases -generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, - _add::MulAddMul = MulAddMul()) = - _generic_matvecmul!(C, tA, A, B, _add) +@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, + _add::MulAddMul = MulAddMul()) + if tA in ('H', 'h', 'S', 's') + return _generic_matvecmul!(C, 'N', wrap(A, tA), B, _add) + else + return _generic_matvecmul!(C, tA, A, B, _add) + end +end function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) @@ -731,21 +760,28 @@ function generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::Abs mB, nB = lapack_size(tB, B) mC, nC = size(C) + Anew, ta = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) + Bnew, tb = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) + if iszero(_add.alpha) return _rmul_or_fill!(C, _add.beta) end if mA == nA == mB == nB == mC == nC == 2 - return matmul2x2!(C, tA, tB, A, B, _add) + return matmul2x2!(C, ta, tb, Anew, Bnew, _add) end if mA == nA == mB == nB == mC == nC == 3 - return matmul3x3!(C, tA, tB, A, B, _add) + return matmul3x3!(C, ta, tb, Anew, Bnew, _add) end - _generic_matmatmul!(C, tA, tB, A, B, _add) + _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) end -generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) = - _generic_matmatmul!(C, tA, tB, A, B, _add) - +function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) + if tA in ('H', 'h', 'S', 's') + return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) + else + return _generic_matmatmul!(C, tA, tB, A, B, _add) + end +end function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, _add::MulAddMul) where {T,S,R} require_one_based_indexing(C, A, B) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index f96ca812ea0ec..8d9c25515e5c8 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -185,6 +185,9 @@ function hermitian_type(::Type{T}) where {S<:AbstractMatrix, T<:AbstractMatrix{S end hermitian_type(::Type{T}) where {T<:Number} = T +_unwrap(A::Hermitian) = parent(A) +_unwrap(A::Symmetric) = parent(A) + for (S, H) in ((:Symmetric, :Hermitian), (:Hermitian, :Symmetric)) @eval begin $S(A::$S) = A @@ -512,90 +515,6 @@ for f in (:+, :-) end end -## Matvec -@inline function mul!(y::StridedVector{T}, A::Symmetric{T,<:StridedMatrix}, x::StridedVector{T}, - α::Number, β::Number) where {T<:BlasFloat} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.symv!(A.uplo, alpha, A.data, x, beta, y) - else - return generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) - end -end -@inline function mul!(y::StridedVector{T}, A::Hermitian{T,<:StridedMatrix}, x::StridedVector{T}, - α::Number, β::Number) where {T<:BlasReal} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.symv!(A.uplo, alpha, A.data, x, beta, y) - else - return generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) - end -end -@inline function mul!(y::StridedVector{T}, A::Hermitian{T,<:StridedMatrix}, x::StridedVector{T}, - α::Number, β::Number) where {T<:BlasComplex} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.hemv!(A.uplo, alpha, A.data, x, beta, y) - else - return generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) - end -end -## Matmat -@inline function mul!(C::StridedMatrix{T}, A::Symmetric{T,<:StridedMatrix}, B::StridedMatrix{T}, - α::Number, β::Number) where {T<:BlasFloat} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.symm!('L', A.uplo, alpha, A.data, B, beta, C) - else - return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) - end -end -@inline function mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Symmetric{T,<:StridedMatrix}, - α::Number, β::Number) where {T<:BlasFloat} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.symm!('R', B.uplo, alpha, B.data, A, beta, C) - else - return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) - end -end -@inline function mul!(C::StridedMatrix{T}, A::Hermitian{T,<:StridedMatrix}, B::StridedMatrix{T}, - α::Number, β::Number) where {T<:BlasReal} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.symm!('L', A.uplo, alpha, A.data, B, beta, C) - else - return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) - end -end -@inline function mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,<:StridedMatrix}, - α::Number, β::Number) where {T<:BlasReal} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.symm!('R', B.uplo, alpha, B.data, A, beta, C) - else - return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) - end -end -@inline function mul!(C::StridedMatrix{T}, A::Hermitian{T,<:StridedMatrix}, B::StridedMatrix{T}, - α::Number, β::Number) where {T<:BlasComplex} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.hemm!('L', A.uplo, alpha, A.data, B, beta, C) - else - return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) - end -end -@inline function mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,<:StridedMatrix}, - α::Number, β::Number) where {T<:BlasComplex} - alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - return BLAS.hemm!('R', B.uplo, alpha, B.data, A, beta, C) - else - return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) - end -end - *(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B) function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector) From 2fe6f4d9c19dc332228aadca24f99ec635b7b74a Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 18 May 2023 17:37:08 +0200 Subject: [PATCH 2/7] fix (may need to remove assertions) --- stdlib/LinearAlgebra/src/matmul.jl | 75 ++++++++++++++++++------------ 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 5479b85e4d58e..a7807f83d1dec 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -343,19 +343,20 @@ lmul!(A, B) # THE one big BLAS dispatch @inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} + if tA == 'T' && tB == 'N' && A === B + return syrk_wrapper!(C, 'T', A, _add) + elseif tA == 'N' && tB == 'T' && A === B + return syrk_wrapper!(C, 'N', A, _add) + elseif tA == 'C' && tB == 'N' && A === B + return herk_wrapper!(C, 'C', A, _add) + elseif tA == 'N' && tB == 'C' && A === B + return herk_wrapper!(C, 'N', A, _add) + elseif tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') + return gemm_wrapper!(C, tA, tB, A, B, _add) + end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - if tA == 'T' && tB == 'N' && A === B - return syrk_wrapper!(C, 'T', A, _add) - elseif tA == 'N' && tB == 'T' && A === B - return syrk_wrapper!(C, 'N', A, _add) - elseif tA == 'C' && tB == 'N' && A === B - return herk_wrapper!(C, 'C', A, _add) - elseif tA == 'N' && tB == 'C' && A === B - return herk_wrapper!(C, 'N', A, _add) - elseif tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') - return gemm_wrapper!(C, tA, tB, A, B, _add) - elseif (tA == 'S' || tA == 's') && tB == 'N' + if (tA == 'S' || tA == 's') && tB == 'N' return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) elseif (tB == 'S' || tB == 's') && tA == 'N' return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) @@ -442,7 +443,8 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) return y else - return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) end end @@ -464,6 +466,10 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :]) BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) return y + elseif tA in ('S', 's', 'H', 'h') + # re-wrap again and use plain ('N') matvec mul algorithm, + # because _generic_matvecmul! can't handle the HermOrSym cases specifically + return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) else return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end @@ -576,11 +582,14 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar return _rmul_or_fill!(C, _add.beta) end + Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + Bnew, tb = tB in ('S', 's', 'H', 'h') ? (wrap(B, tB), 'N') : (B, tB) + if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, tA, tB, A, B, _add) + return matmul2x2!(C, ta, tb, Anew, Bnew, _add) end if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, A, B, _add) + return matmul3x3!(C, tA, tB, Anew, Bnew, _add) end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) @@ -589,10 +598,11 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && - stride(C, 2) >= size(C, 1)) + stride(C, 2) >= size(C, 1) && + tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - _generic_matmatmul!(C, tA, tB, A, B, _add) + _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) end function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, @@ -616,11 +626,14 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs return _rmul_or_fill!(C, _add.beta) end + Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + Bnew, tb = tB in ('S', 's', 'H', 'h') ? (wrap(B, tB), 'N') : (B, tB) + if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, tA, tB, A, B, _add) + return matmul2x2!(C, ta, tb, Anew, Bnew, _add) end if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, A, B, _add) + return matmul3x3!(C, ta, tb, Anew, Bnew, _add) end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) @@ -631,11 +644,11 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && - stride(C, 2) >= size(C, 1)) && tA == 'N' + stride(C, 2) >= size(C, 1)) && tA == 'N' && tB in ('N', 'T', 'C') BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end - _generic_matmatmul!(C, tA, tB, A, B, _add) + _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) end # blas.jl defines matmul for floats; other integer and mixed precision @@ -671,16 +684,14 @@ end @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) - if tA in ('H', 'h', 'S', 's') - return _generic_matvecmul!(C, 'N', wrap(A, tA), B, _add) - else - return _generic_matvecmul!(C, tA, A, B, _add) - end + Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + return _generic_matvecmul!(C, ta, Anew, B, _add) end function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) + @assert tA in ('N', 'T', 'C') mB = length(B) mA, nA = lapack_size(tA, A) if mB != nA @@ -776,15 +787,15 @@ function generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::Abs end function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) - if tA in ('H', 'h', 'S', 's') - return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) - else - return _generic_matmatmul!(C, tA, tB, A, B, _add) - end + Anew, ta = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) + Bnew, tb = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) + return _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) end function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, _add::MulAddMul) where {T,S,R} require_one_based_indexing(C, A, B) + @assert tA in ('N', 'C', 'T') + @assert tB in ('N', 'C', 'T') mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if mB != nA @@ -963,6 +974,8 @@ end function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) + @assert tA in ('N', 'T', 'C') + @assert tB in ('N', 'T', 'C') if !(size(A) == size(B) == size(C) == (2,2)) throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) end @@ -1006,6 +1019,8 @@ end function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) + @assert tA in ('N', 'T', 'C') + @assert tB in ('N', 'T', 'C') if !(size(A) == size(B) == size(C) == (3,3)) throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) end From b7bb7258c7da852b666b7a6197b6504593124b9c Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 18 May 2023 21:50:01 +0200 Subject: [PATCH 3/7] fix performance issue --- stdlib/LinearAlgebra/src/matmul.jl | 45 +++++++++++++----------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index a7807f83d1dec..f894d207c226c 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -343,16 +343,18 @@ lmul!(A, B) # THE one big BLAS dispatch @inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} - if tA == 'T' && tB == 'N' && A === B - return syrk_wrapper!(C, 'T', A, _add) - elseif tA == 'N' && tB == 'T' && A === B - return syrk_wrapper!(C, 'N', A, _add) - elseif tA == 'C' && tB == 'N' && A === B - return herk_wrapper!(C, 'C', A, _add) - elseif tA == 'N' && tB == 'C' && A === B - return herk_wrapper!(C, 'N', A, _add) - elseif tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') - return gemm_wrapper!(C, tA, tB, A, B, _add) + if all(in(('N', 'T', 'C')), (tA, tB)) + if tA == 'T' && tB == 'N' && A === B + return syrk_wrapper!(C, 'T', A, _add) + elseif tA == 'N' && tB == 'T' && A === B + return syrk_wrapper!(C, 'N', A, _add) + elseif tA == 'C' && tB == 'N' && A === B + return herk_wrapper!(C, 'C', A, _add) + elseif tA == 'N' && tB == 'C' && A === B + return herk_wrapper!(C, 'N', A, _add) + else + return gemm_wrapper!(C, tA, tB, A, B, _add) + end end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} @@ -582,14 +584,11 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar return _rmul_or_fill!(C, _add.beta) end - Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) - Bnew, tb = tB in ('S', 's', 'H', 'h') ? (wrap(B, tB), 'N') : (B, tB) - if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, ta, tb, Anew, Bnew, _add) + return matmul2x2!(C, tA, tB, A, B, _add) end if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, Anew, Bnew, _add) + return matmul3x3!(C, tA, tB, A, B, _add) end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) @@ -598,11 +597,10 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && - stride(C, 2) >= size(C, 1) && - tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')) + stride(C, 2) >= size(C, 1)) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) + _generic_matmatmul!(C, tA, tB, A, B, _add) end function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, @@ -626,14 +624,11 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs return _rmul_or_fill!(C, _add.beta) end - Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) - Bnew, tb = tB in ('S', 's', 'H', 'h') ? (wrap(B, tB), 'N') : (B, tB) - if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, ta, tb, Anew, Bnew, _add) + return matmul2x2!(C, tA, tB, A, B, _add) end if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, ta, tb, Anew, Bnew, _add) + return matmul3x3!(C, tA, tB, A, B, _add) end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) @@ -644,11 +639,11 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && - stride(C, 2) >= size(C, 1)) && tA == 'N' && tB in ('N', 'T', 'C') + stride(C, 2) >= size(C, 1) && tA == 'N') BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end - _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) + _generic_matmatmul!(C, tA, tB, A, B, _add) end # blas.jl defines matmul for floats; other integer and mixed precision From 1dd15a309e8f64342c4c21622241e81135c90eb0 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 19 May 2023 13:58:11 +0200 Subject: [PATCH 4/7] another small fix --- stdlib/LinearAlgebra/src/matmul.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index f894d207c226c..5d57dc0d40ee5 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -374,7 +374,11 @@ end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. @inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} - gemm_wrapper!(C, tA, tB, A, B, _add) + if all(in(('N', 'T', 'C')), (tA, tB)) + gemm_wrapper!(C, tA, tB, A, B, _add) + else + _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) + end end @@ -560,7 +564,11 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) C = similar(B, T, mA, nB) - gemm_wrapper!(C, tA, tB, A, B) + if all(in(('N', 'T', 'C')), (tA, tB)) + gemm_wrapper!(C, tA, tB, A, B) + else + _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) + end end function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, From 9fe30f58b9a419d3bae7b36715f4dcac19716f23 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 19 May 2023 15:58:08 +0200 Subject: [PATCH 5/7] fix real-complex gemv --- stdlib/LinearAlgebra/src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 5d57dc0d40ee5..bfbb8b4374365 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -466,7 +466,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa alpha, beta = promote(α, β, zero(T)) @views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - !iszero(stride(x, 1)) + !iszero(stride(x, 1)) && tA in ('N', 'T', 'C') xfl = reinterpret(reshape, T, x) # Use reshape here. yfl = reinterpret(reshape, T, y) BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :]) From d13d177a926769719d3eecb183722b66836ca3e6 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 22 May 2023 20:50:56 +0200 Subject: [PATCH 6/7] Handle `HermOrSym` explicitly for small matrices --- stdlib/LinearAlgebra/src/matmul.jl | 122 ++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 39 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index bfbb8b4374365..6f9e39f494d50 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -768,37 +768,29 @@ end const tilebufsize = 10800 # Approximately 32k/3 -function generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, - _add::MulAddMul=MulAddMul()) +function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) mC, nC = size(C) - Anew, ta = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) - Bnew, tb = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) - if iszero(_add.alpha) return _rmul_or_fill!(C, _add.beta) end if mA == nA == mB == nB == mC == nC == 2 - return matmul2x2!(C, ta, tb, Anew, Bnew, _add) + return matmul2x2!(C, tA, tB, A, B, _add) end if mA == nA == mB == nB == mC == nC == 3 - return matmul3x3!(C, ta, tb, Anew, Bnew, _add) + return matmul3x3!(C, tA, tB, A, B, _add) end - _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) + _generic_matmatmul!(C, tA, tB, A, B, _add) end -function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) - Anew, ta = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) - Bnew, tb = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) - return _generic_matmatmul!(C, ta, tb, Anew, Bnew, _add) -end function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, _add::MulAddMul) where {T,S,R} require_one_based_indexing(C, A, B) - @assert tA in ('N', 'C', 'T') - @assert tB in ('N', 'C', 'T') + A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) + B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) + mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if mB != nA @@ -977,13 +969,13 @@ end function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) - @assert tA in ('N', 'T', 'C') - @assert tB in ('N', 'T', 'C') if !(size(A) == size(B) == size(C) == (2,2)) throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) end @inbounds begin - if tA == 'T' + if tA == 'N' + A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2] + elseif tA == 'T' # TODO making these lazy could improve perf A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1])) A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2])) @@ -991,10 +983,23 @@ function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat # TODO making these lazy could improve perf A11 = copy(A[1,1]'); A12 = copy(A[2,1]') A21 = copy(A[1,2]'); A22 = copy(A[2,2]') - else - A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2] - end - if tB == 'T' + elseif tA == 'S' + A11 = symmetric(A[1,1], :U); A12 = A[1,2] + A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U) + elseif tA == 's' + A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])) + A21 = A[2,1]; A22 = symmetric(A[2,2], :L) + elseif tA == 'H' + A11 = hermitian(A[1,1], :U); A12 = A[1,2] + A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U) + else # if tA == 'h' + A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])) + A21 = A[2,1]; A22 = hermitian(A[2,2], :L) + end + if tB == 'N' + B11 = B[1,1]; B12 = B[1,2]; + B21 = B[2,1]; B22 = B[2,2] + elseif tB == 'T' # TODO making these lazy could improve perf B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])) B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])) @@ -1002,9 +1007,18 @@ function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat # TODO making these lazy could improve perf B11 = copy(B[1,1]'); B12 = copy(B[2,1]') B21 = copy(B[1,2]'); B22 = copy(B[2,2]') - else - B11 = B[1,1]; B12 = B[1,2]; - B21 = B[2,1]; B22 = B[2,2] + elseif tA == 'S' + B11 = symmetric(A[1,1], :U); B12 = A[1,2] + B21 = copy(transpose(A[1,2])); B22 = symmetric(A[2,2], :U) + elseif tA == 's' + B11 = symmetric(A[1,1], :L); B12 = copy(transpose(A[2,1])) + B21 = A[2,1]; B22 = symmetric(A[2,2], :L) + elseif tA == 'H' + B11 = hermitian(A[1,1], :U); B12 = A[1,2] + B21 = copy(adjoint(A[1,2])); B22 = hermitian(A[2,2], :U) + else # if tA == 'h' + B11 = hermitian(A[1,1], :L); B12 = copy(adjoint(A[2,1])) + B21 = A[2,1]; B22 = hermitian(A[2,2], :L) end _modify!(_add, A11*B11 + A12*B21, C, (1,1)) _modify!(_add, A11*B12 + A12*B22, C, (1,2)) @@ -1022,13 +1036,15 @@ end function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) - @assert tA in ('N', 'T', 'C') - @assert tB in ('N', 'T', 'C') if !(size(A) == size(B) == size(C) == (3,3)) throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) end @inbounds begin - if tA == 'T' + if tA == 'N' + A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3] + A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3] + A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3] + elseif tA == 'T' # TODO making these lazy could improve perf A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1])) A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2])); A23 = copy(transpose(A[3,2])) @@ -1038,13 +1054,29 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat A11 = copy(A[1,1]'); A12 = copy(A[2,1]'); A13 = copy(A[3,1]') A21 = copy(A[1,2]'); A22 = copy(A[2,2]'); A23 = copy(A[3,2]') A31 = copy(A[1,3]'); A32 = copy(A[2,3]'); A33 = copy(A[3,3]') - else - A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3] - A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3] - A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3] - end - - if tB == 'T' + elseif tA == 'S' + A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3] + A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3] + A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U) + elseif tA == 's' + A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1])) + A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2])) + A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L) + elseif tA == 'H' + A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3] + A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3] + A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U) + else # if tA == 'h' + A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1])) + A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2])) + A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L) + end + + if tB == 'N' + B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3] + B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3] + B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3] + elseif tB == 'T' # TODO making these lazy could improve perf B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1])) B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])); B23 = copy(transpose(B[3,2])) @@ -1054,10 +1086,22 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat B11 = copy(B[1,1]'); B12 = copy(B[2,1]'); B13 = copy(B[3,1]') B21 = copy(B[1,2]'); B22 = copy(B[2,2]'); B23 = copy(B[3,2]') B31 = copy(B[1,3]'); B32 = copy(B[2,3]'); B33 = copy(B[3,3]') - else - B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3] - B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3] - B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3] + elseif tB == 'S' + B11 = symmetric(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3] + B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U); B23 = B[2,3] + B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = symmetric(B[3,3], :U) + elseif tB == 's' + B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1])) + B21 = B[2,1]; B22 = symmetric(B[2,2], :L); B23 = copy(transpose(B[3,2])) + B31 = B[3,1]; B32 = B[3,2]; B33 = symmetric(B[3,3], :L) + elseif tB == 'H' + B11 = hermitian(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3] + B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U); B23 = B[2,3] + B31 = copy(adjoint(B[1,3])); B32 = copy(adjoint(B[2,3])); B33 = hermitian(B[3,3], :U) + else # if tB == 'h' + B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1])); B13 = copy(adjoint(B[3,1])) + B21 = B[2,1]; B22 = hermitian(B[2,2], :L); B23 = copy(adjoint(B[3,2])) + B31 = B[3,1]; B32 = B[3,2]; B33 = hermitian(B[3,3], :L) end _modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1)) From a3297e5afdf3971440a0e39716f40309c5589ebe Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 22 May 2023 21:44:53 +0200 Subject: [PATCH 7/7] fix performance drop in generic mul --- stdlib/LinearAlgebra/src/matmul.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 6f9e39f494d50..e9839857f93e6 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -782,14 +782,15 @@ function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B: if mA == nA == mB == nB == mC == nC == 3 return matmul3x3!(C, tA, tB, A, B, _add) end + A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) + B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) _generic_matmatmul!(C, tA, tB, A, B, _add) end function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, _add::MulAddMul) where {T,S,R} + @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') require_one_based_indexing(C, A, B) - A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) - B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B)