Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Try #641:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Mar 19, 2020
2 parents 15e6ef6 + 5c33ad3 commit 96028ea
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 36 deletions.
10 changes: 6 additions & 4 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ version = "4.0.0"

[[CUDAdrv]]
deps = ["CEnum", "CUDAapi", "Printf"]
git-tree-sha1 = "f176b994a6e4c70aafa626cbf825aa9c34adc9e6"
git-tree-sha1 = "9db0ff78ac601ca66c85f3bededdbacc3d1c9bdb"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "6.1.0"
version = "6.2.0"

[[CUDAnative]]
deps = ["Adapt", "BinaryProvider", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "MacroTools", "Pkg", "Printf", "TimerOutputs"]
Expand All @@ -62,9 +62,11 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "b385075caff384494fdda11300755d667b28b333"
git-tree-sha1 = "e68ff0162eec49362685a6db1a543547b0eb101f"
repo-rev = "b1be744d4306dded35fbb055cea20c90291a7d0f"
repo-url = "https://github.com/JuliaGPU/GPUArrays.jl.git"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "3.0.0"
version = "3.0.1"

[[InteractiveUtils]]
deps = ["Markdown"]
Expand Down
117 changes: 85 additions & 32 deletions src/blas/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,31 @@ function gemv_wrapper!(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
gemv!(tA, alpha, A, x, beta, y)
end

LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'N', A, B)
LinearAlgebra.lmul!(Y::CuVector{T}, A::Transpose{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B)
LinearAlgebra.lmul!(Y::CuVector{T}, A::Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B)
LinearAlgebra.lmul!(Y::CuVector{T}, A::Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasComplex = gemv_wrapper!(Y, 'C', A.parent, B)
function promote_alpha_beta(a, b, ::Type{T}) where {T}
a_prom, b_prom = promote(a, b, zero(T))
a_prom, b_prom
end

LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasReal =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, promote_alpha_beta(a, b, T)...)

# Fix Julia <= 1.3.1 ambiguities... they're fixed in 1.4.x thanks to https://github.com/JuliaLang/julia/pull/33743
@static if v"1.3.0" <= VERSION <= v"1.3.1"
LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemv_wrapper!(Y, 'T', A.parent, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuMatrix{T}}, B::CuVector{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, promote_alpha_beta(a, b, T)...)
end

# TRSV

Expand Down Expand Up @@ -156,34 +177,66 @@ function gemm_wrapper!(C::CuVecOrMat{T}, tA::Char, tB::Char,
end

# Mutating
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}) where T<:CublasFloat = gemm_wrapper!(C, 'N', 'N', A, B)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'N', parent(trA), B)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'T', A, parent(trB))
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'N', parent(adjA), B)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}) where T<:CublasFloat =
gemm_wrapper!(C, 'C', 'N', parent(adjA), B)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'N', 'T', A, parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'C', A, parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB))
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}) where T <: CublasFloat =
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB))

LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'N', parent(trA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'T', A, parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'N', 'T', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'N', 'C', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasComplex =
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Number, b::Number) where T <: CublasComplex =
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)

# Fix Julia <= 1.3.1 ambiguities... they're fixed in 1.4.x thanks to https://github.com/JuliaLang/julia/pull/33743
@static if v"1.3.0" <= VERSION <= v"1.3.1"
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'N', A, B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'N', parent(trA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'N', 'T', A, parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasFloat =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, B::CuMatrix{T}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'N', parent(adjA), B, promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'N', 'T', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'N', 'C', A, parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{T, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, trA::Transpose{<:Any, <:CuMatrix{T}}, adjB::Adjoint{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasComplex =
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{T, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T<:CublasReal =
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
LinearAlgebra.mul!(C::CuMatrix{T}, adjA::Adjoint{<:Any, <:CuMatrix{T}}, trB::Transpose{<:Any, <:CuMatrix{T}}, a::Union{T,Bool}, b::Union{T,Bool}) where T <: CublasComplex =
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB), promote_alpha_beta(a, b, T)...)
end

# TRSM

Expand Down
14 changes: 14 additions & 0 deletions test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ end # level 1 testset
dA = CuArray(A)
@test_throws DimensionMismatch mul!(dy, dA, dx)
end
@testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (identity, transpose, adjoint), Ts in (Int, elty)
y, A, x = rand(elty, 5), rand(elty, 5, 5), rand(elty, 5)
dy, dA, dx = CuArray(y), CuArray(A), CuArray(x)
mul!(dy, f(dA), dx, Ts(1), Ts(1))
mul!(y, f(A), x, elty(1), elty(2)) # elty can be replaced with `Ts` on Julia 1.4
@test Array(dy) y
end
@testset "banded methods" begin
# bands
ku = 2
Expand Down Expand Up @@ -399,6 +406,13 @@ end # level 1 testset
end
end
@testset "Level 3" begin
@testset "mul! C = $f(A) * $g(B) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty)
C, A, B = rand(elty, 5, 5), rand(elty, 5, 5), rand(elty, 5, 5)
dC, dA, dB = CuArray(C), CuArray(A), CuArray(B)
mul!(dC, f(dA), g(dB), Ts(1), Ts(2))
mul!(C, f(A), g(B), elty(1), elty(2)) # elty can be replaced with `Ts` on Julia 1.4
@test Array(dC) C
end
A = rand(elty,m,k)
B = rand(elty,k,n)
Bbad = rand(elty,k+1,n+1)
Expand Down

0 comments on commit 96028ea

Please sign in to comment.