Skip to content

Commit

Permalink
Simplify mul! dispatch (#49806)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored May 15, 2023
1 parent fbbe9ed commit 15d7bd8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 77 deletions.
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ inplace_adj_or_trans(::Type{<:Transpose}) = transpose!
adj_or_trans_char(::T) where {T<:AbstractArray} = adj_or_trans_char(T)
adj_or_trans_char(::Type{<:AbstractArray}) = 'N'
adj_or_trans_char(::Type{<:Adjoint}) = 'C'
adj_or_trans_char(::Type{<:Adjoint{<:Real}}) = 'T'
adj_or_trans_char(::Type{<:Transpose}) = 'T'

Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent)
Expand Down
117 changes: 40 additions & 77 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,22 @@ end
alpha::Number, beta::Number) =
generic_matvecmul!(y, adj_or_trans_char(A), _parent(A), x, MulAddMul(alpha, beta))
# BLAS cases
@inline mul!(y::StridedVector{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, adj_or_trans_char(A), _parent(A), x, alpha, beta)
# catch the real adjoint case and rewrap to transpose
@inline mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
mul!(y, transpose(adjA.parent), x, alpha, beta)
# equal eltypes
@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat} =
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)
# Real (possibly transposed) matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)
# Complex matrix times real vector.
# Reinterpret the matrix as a real matrix and do real matvec computation.
@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, 'N', A, x, alpha, beta)
# Real matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, A isa StridedArray ? 'N' : 'T', _parent(A), x, alpha, beta)
# works only in cooperation with BLAS when A is untransposed (tA == 'N')
# but that check is included in gemv! anyway
@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)

# Vector-Matrix multiplication
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
Expand Down Expand Up @@ -341,66 +340,26 @@ julia> lmul!(F.Q, B)
"""
lmul!(A, B)

# generic case
@inline mul!(C::StridedMatrix{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, B::StridedMaybeAdjOrTransVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta))

# AtB & ABt (including B === A)
@inline function mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasFloat}
A = tA.parent
if A === B
return syrk_wrapper!(C, 'T', A, MulAddMul(alpha, beta))
else
return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasFloat}
B = tB.parent
if A === B
return syrk_wrapper!(C, 'N', A, MulAddMul(alpha, beta))
else
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
end
end
# real adjoint cases, also needed for disambiguation
@inline mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
mul!(C, A, transpose(adjB.parent), alpha, beta)
@inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Real, beta::Real) where {T<:BlasReal} =
mul!(C, transpose(adjA.parent), B, alpha, beta)

# AcB & ABc (including B === A)
@inline function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasComplex}
A = adjA.parent
if A === B
return herk_wrapper!(C, 'C', A, MulAddMul(alpha, beta))
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
if tA == 'T' && tB == 'N' && A === B
return syrk_wrapper!(C, 'T', A, _add)
elseif tA == 'N' && tB == 'T' && A === B
return syrk_wrapper!(C, 'N', A, _add)
elseif tA == 'C' && tB == 'N' && A === B
return herk_wrapper!(C, 'C', A, _add)
elseif tA == 'N' && tB == 'C' && A === B
return herk_wrapper!(C, 'N', A, _add)
else
return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasComplex}
B = adjB.parent
if A === B
return herk_wrapper!(C, 'N', A, MulAddMul(alpha, beta))
else
return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta))
return gemm_wrapper!(C, tA, tB, A, B, _add)
end
end

# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedMaybeAdjOrTransVecOrMat{Complex{T}}, B::StridedMaybeAdjOrTransVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta))
# catch the real adjoint case and interpret it as a transpose
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
mul!(C, A, transpose(adjB.parent), alpha, beta)
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
gemm_wrapper!(C, tA, tB, A, B, _add)
end


# Supporting functions for matrix multiplication
Expand Down Expand Up @@ -438,7 +397,7 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
!iszero(stride(x, 1)) # We only check input's stride here.
return BLAS.gemv!(tA, alpha, A, x, beta, y)
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

Expand All @@ -459,7 +418,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

Expand All @@ -482,7 +441,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

Expand Down Expand Up @@ -609,7 +568,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
stride(C, 2) >= size(C, 1))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, tA, tB, A, B, _add)
end

function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
Expand Down Expand Up @@ -652,7 +611,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, tA, tB, A, B, _add)
end

# blas.jl defines matmul for floats; other integer and mixed precision
Expand Down Expand Up @@ -686,8 +645,12 @@ end
# NOTE: the generic version is also called as fallback for
# strides != 1 cases

function generic_matvecmul!(C::AbstractVector{R}, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul()) where R
generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul()) =
_generic_matvecmul!(C, tA, A, B, _add)

function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
mB = length(B)
mA, nA = lapack_size(tA, A)
Expand Down

0 comments on commit 15d7bd8

Please sign in to comment.