Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Merge #634
Browse files Browse the repository at this point in the history
634: Provide 5-arg mul! r=maleadt a=haampie

Since Julia 1.3 we have a 5-arg mul! for BLAS, but somehow it has not made its way into CuArrays.

In principle this should not be a breaking change, however, I've commented out `LinearAlgebra.lmul!` because it is not provided by Julia base as far as I know. So that is breaking, but I don't think anybody uses it (?).

Finally it fixes an oversight in `LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, CuMatrix{T}})` where no `<:` was used 😅 


Co-authored-by: Harmen Stoppels <harmenstoppels@gmail.com>
  • Loading branch information
bors[bot] and haampie authored Mar 16, 2020
2 parents 52c4664 + e8648a7 commit 5542a9e
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 32 deletions.
117 changes: 85 additions & 32 deletions src/blas/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,31 @@ function gemv_wrapper!(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
gemv!(tA, alpha, A, x, beta, y)
end

LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'N', A, B)
LinearAlgebra.lmul!(Y::CuVector{T}, A::Transpose{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B)
LinearAlgebra.lmul!(Y::CuVector{T}, A::Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B)
LinearAlgebra.lmul!(Y::CuVector{T}, A::Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasComplex = gemv_wrapper!(Y, 'C', A.parent, B)
function promote_alpha_beta(a, b, ::Type{T}) where {T}
a_prom, b_prom = promote(a, b, zero(T))
a_prom, b_prom
end

LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasReal =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, promote_alpha_beta(a, b, T)...)

# Fix Julia 1.3.0 ambiguities... they're fixed in 1.3.1 thanks to https://github.com/JuliaLang/julia/pull/33743
@static if VERSION === v"1.3.0"
LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, promote_alpha_beta(a, b, T)...)
end

# TRSV

Expand Down Expand Up @@ -156,34 +177,66 @@ function gemm_wrapper!(C::CuVecOrMat{T}, tA::Char, tB::Char,
end

# Mutating
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}) where T<:CublasFloat = gemm_wrapper!(C, 'N', 'N', A, B)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'N', parent(trA), B)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'T', A, parent(trB))
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'N', parent(adjA), B)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasFloat =
gemm_wrapper!(C, 'C', 'N', parent(adjA), B)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'N', 'T', A, parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'C', A, parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T <: CublasFloat =
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB))

LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'N', parent(trA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'T', A, parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'N', 'T', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'N', 'C', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T <: CublasComplex =
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)

# Fix Julia 1.3.0 ambiguities... they're fixed in 1.3.1 thanks to https://github.com/JuliaLang/julia/pull/33743
@static if VERSION === v"1.3.0"
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'N', parent(trA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'T', A, parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'N', 'T', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'N', 'C', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T <: CublasComplex =
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
end

# TRSM

Expand Down
14 changes: 14 additions & 0 deletions test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ end # level 1 testset
dA = CuArray(A)
@test_throws DimensionMismatch mul!(dy, dA, dx)
end
@testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (identity, transpose, adjoint), Ts in (Int, elty)
y, A, x = rand(elty, 5), rand(elty, 5, 5), rand(elty, 5)
dy, dA, dx = CuArray(y), CuArray(A), CuArray(x)
mul!(dy, f(dA), dx, Ts(1), Ts(1))
mul!(y, f(A), x, elty(1), elty(2)) # elty can be replaced with `Ts` on Julia 1.4
@test Array(dy) y
end
@testset "banded methods" begin
# bands
ku = 2
Expand Down Expand Up @@ -399,6 +406,13 @@ end # level 1 testset
end
end
@testset "Level 3" begin
@testset "mul! C = $f(A) * $g(B) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty)
C, A, B = rand(elty, 5, 5), rand(elty, 5, 5), rand(elty, 5, 5)
dC, dA, dB = CuArray(C), CuArray(A), CuArray(B)
mul!(dC, f(dA), g(dB), Ts(1), Ts(2))
mul!(C, f(A), g(B), elty(1), elty(2)) # elty can be replaced with `Ts` on Julia 1.4
@test Array(dC) C
end
A = rand(elty,m,k)
B = rand(elty,k,n)
Bbad = rand(elty,k+1,n+1)
Expand Down

0 comments on commit 5542a9e

Please sign in to comment.