diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index ad8d8e91af299..c065c692a06ef 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -930,164 +930,143 @@ end # multiply 2x2 matrices -function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S} +Base.@constprop :aggressive function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S} matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B) end -function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, - _add::MulAddMul = MulAddMul()) +function __matmul_checks(C, A, B, sz) require_one_based_indexing(C, A, B) if C === A || B === C throw(ArgumentError("output matrix must not be aliased with input matrix")) end - if !(size(A) == size(B) == size(C) == (2,2)) + if !(size(A) == size(B) == size(C) == sz) throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) end + return nothing +end + +# separate function with the core of matmul2x2! that doesn't depend on a MulAddMul +Base.@constprop :aggressive function _matmul2x2_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix) + __matmul_checks(C, A, B, (2,2)) + __matmul2x2_elements(tA, tB, A, B) +end +Base.@constprop :aggressive function __matmul2x2_elements(tA, A::AbstractMatrix) @inbounds begin - if tA == 'N' + tA_uc = uppercase(tA) # possibly unwrap a WrapperChar + if tA_uc == 'N' A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2] - elseif tA == 'T' + elseif tA_uc == 'T' # TODO making these lazy could improve perf A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1])) A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2])) - elseif tA == 'C' + elseif tA_uc == 'C' # TODO making these lazy could improve perf A11 = copy(A[1,1]'); A12 = copy(A[2,1]') A21 = copy(A[1,2]'); A22 = copy(A[2,2]') - elseif tA == 'S' - A11 = symmetric(A[1,1], :U); A12 = A[1,2] - A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U) - elseif tA == 's' - A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])) - A21 = A[2,1]; A22 = symmetric(A[2,2], :L) - elseif tA == 'H' - A11 = hermitian(A[1,1], :U); A12 = A[1,2] - A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U) - else # if tA == 'h' - A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])) - A21 = A[2,1]; A22 = hermitian(A[2,2], :L) - end - if tB == 'N' - B11 = B[1,1]; B12 = B[1,2]; - B21 = B[2,1]; B22 = B[2,2] - elseif tB == 'T' - # TODO making these lazy could improve perf - B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])) - B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])) - elseif tB == 'C' - # TODO making these lazy could improve perf - B11 = copy(B[1,1]'); B12 = copy(B[2,1]') - B21 = copy(B[1,2]'); B22 = copy(B[2,2]') - elseif tB == 'S' - B11 = symmetric(B[1,1], :U); B12 = B[1,2] - B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U) - elseif tB == 's' - B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1])) - B21 = B[2,1]; B22 = symmetric(B[2,2], :L) - elseif tB == 'H' - B11 = hermitian(B[1,1], :U); B12 = B[1,2] - B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U) - else # if tB == 'h' - B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1])) - B21 = B[2,1]; B22 = hermitian(B[2,2], :L) + elseif tA_uc == 'S' + if isuppercase(tA) # tA == 'S' + A11 = symmetric(A[1,1], :U); A12 = A[1,2] + A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U) + else + A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])) + A21 = A[2,1]; A22 = symmetric(A[2,2], :L) + end + elseif tA_uc == 'H' + if isuppercase(tA) # tA == 'H' + A11 = hermitian(A[1,1], :U); A12 = A[1,2] + A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U) + else # if tA == 'h' + A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])) + A21 = A[2,1]; A22 = hermitian(A[2,2], :L) + end end + end # inbounds + A11, A12, A21, A22 +end +Base.@constprop :aggressive __matmul2x2_elements(tA, tB, A, B) = __matmul2x2_elements(tA, A), __matmul2x2_elements(tB, B) + +Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, + _add::MulAddMul = MulAddMul()) + (A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements(C, tA, tB, A, B) + @inbounds begin _modify!(_add, A11*B11 + A12*B21, C, (1,1)) - _modify!(_add, A11*B12 + A12*B22, C, (1,2)) _modify!(_add, A21*B11 + A22*B21, C, (2,1)) + _modify!(_add, A11*B12 + A12*B22, C, (1,2)) _modify!(_add, A21*B12 + A22*B22, C, (2,2)) end # inbounds C end # Multiply 3x3 matrices -function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S} +Base.@constprop :aggressive function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S} matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B) end -function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, - _add::MulAddMul = MulAddMul()) - require_one_based_indexing(C, A, B) - if C === A || B === C - throw(ArgumentError("output matrix must not be aliased with input matrix")) - end - if !(size(A) == size(B) == size(C) == (3,3)) - throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) - end +# separate function with the core of matmul3x3! that doesn't depend on a MulAddMul +Base.@constprop :aggressive function _matmul3x3_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix) + __matmul_checks(C, A, B, (3,3)) + __matmul3x3_elements(tA, tB, A, B) +end +Base.@constprop :aggressive function __matmul3x3_elements(tA, A::AbstractMatrix) @inbounds begin - if tA == 'N' + tA_uc = uppercase(tA) # possibly unwrap a WrapperChar + if tA_uc == 'N' A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3] A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3] A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3] - elseif tA == 'T' + elseif tA_uc == 'T' # TODO making these lazy could improve perf A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1])) A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2])); A23 = copy(transpose(A[3,2])) A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = copy(transpose(A[3,3])) - elseif tA == 'C' + elseif tA_uc == 'C' # TODO making these lazy could improve perf A11 = copy(A[1,1]'); A12 = copy(A[2,1]'); A13 = copy(A[3,1]') A21 = copy(A[1,2]'); A22 = copy(A[2,2]'); A23 = copy(A[3,2]') A31 = copy(A[1,3]'); A32 = copy(A[2,3]'); A33 = copy(A[3,3]') - elseif tA == 'S' - A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3] - A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3] - A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U) - elseif tA == 's' - A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1])) - A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2])) - A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L) - elseif tA == 'H' - A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3] - A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3] - A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U) - else # if tA == 'h' - A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1])) - A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2])) - A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L) + elseif tA_uc == 'S' + if isuppercase(tA) # tA == 'S' + A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3] + A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3] + A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U) + else + A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1])) + A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2])) + A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L) + end + elseif tA_uc == 'H' + if isuppercase(tA) # tA == 'H' + A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3] + A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3] + A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U) + else # if tA == 'h' + A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1])) + A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2])) + A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L) + end end + end # inbounds + A11, A12, A13, A21, A22, A23, A31, A32, A33 +end +Base.@constprop :aggressive __matmul3x3_elements(tA, tB, A, B) = __matmul3x3_elements(tA, A), __matmul3x3_elements(tB, B) - if tB == 'N' - B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3] - B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3] - B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3] - elseif tB == 'T' - # TODO making these lazy could improve perf - B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1])) - B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])); B23 = copy(transpose(B[3,2])) - B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = copy(transpose(B[3,3])) - elseif tB == 'C' - # TODO making these lazy could improve perf - B11 = copy(B[1,1]'); B12 = copy(B[2,1]'); B13 = copy(B[3,1]') - B21 = copy(B[1,2]'); B22 = copy(B[2,2]'); B23 = copy(B[3,2]') - B31 = copy(B[1,3]'); B32 = copy(B[2,3]'); B33 = copy(B[3,3]') - elseif tB == 'S' - B11 = symmetric(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3] - B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U); B23 = B[2,3] - B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = symmetric(B[3,3], :U) - elseif tB == 's' - B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1])) - B21 = B[2,1]; B22 = symmetric(B[2,2], :L); B23 = copy(transpose(B[3,2])) - B31 = B[3,1]; B32 = B[3,2]; B33 = symmetric(B[3,3], :L) - elseif tB == 'H' - B11 = hermitian(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3] - B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U); B23 = B[2,3] - B31 = copy(adjoint(B[1,3])); B32 = copy(adjoint(B[2,3])); B33 = hermitian(B[3,3], :U) - else # if tB == 'h' - B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1])); B13 = copy(adjoint(B[3,1])) - B21 = B[2,1]; B22 = hermitian(B[2,2], :L); B23 = copy(adjoint(B[3,2])) - B31 = B[3,1]; B32 = B[3,2]; B33 = hermitian(B[3,3], :L) - end +Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, + _add::MulAddMul = MulAddMul()) - _modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1)) - _modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2)) - _modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3)) + (A11, A12, A13, A21, A22, A23, A31, A32, A33), + (B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements(C, tA, tB, A, B) + @inbounds begin + _modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1)) _modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1)) - _modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2)) - _modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3)) - _modify!(_add, A31*B11 + A32*B21 + A33*B31, C, (3,1)) + + _modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2)) + _modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2)) _modify!(_add, A31*B12 + A32*B22 + A33*B32, C, (3,2)) + + _modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3)) + _modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3)) _modify!(_add, A31*B13 + A32*B23 + A33*B33, C, (3,3)) end # inbounds C