This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Add direct BLAS calls for trmm! and trsm! #707

Merged May 6, 2020 (3 commits)

aterenin (Contributor) commented May 5, 2020

This allows things like UpperTriangular which call these methods internally to work correctly with CuArrays.

Note that the @eval is needed to avoid method ambiguity. If there's a cleaner way to iterate over subtypes of CublasFloat I'm happy to do that.
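
A minimal sketch of the @eval pattern mentioned above, using toy stand-ins so that it runs without a GPU. ToyGPUMatrix and scale! are hypothetical names, not part of CuArrays or LinearAlgebra; the two per-type methods on AbstractMatrix only mimic the way the stdlib BLAS wrappers appear to be defined, one method per element type.

```julia
# Hypothetical stand-in for a GPU matrix type (not the real CuMatrix).
struct ToyGPUMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::ToyGPUMatrix) = size(A.data)
Base.getindex(A::ToyGPUMatrix, i::Int, j::Int) = A.data[i, j]

# "Library" side: one method per element type, with no where clause.
scale!(alpha::Float32, A::AbstractMatrix{Float32}) = :generic
scale!(alpha::Float64, A::AbstractMatrix{Float64}) = :generic

# "Package" side: generate one concrete method per element type.
# Each generated signature is a strict subtype of the matching library
# signature, so dispatch can pick it unambiguously.
for T in (Float32, Float64)
    @eval scale!(alpha::$T, A::ToyGPUMatrix{$T}) = :gpu
end

result32 = scale!(1f0, ToyGPUMatrix(ones(Float32, 2, 2)))
result64 = scale!(1.0, ToyGPUMatrix(ones(Float64, 2, 2)))
fallback = scale!(1.0, ones(2, 2))  # plain Matrix still hits the library method
```

The loop trades a little metaprogramming for signatures that are concrete in the element type, which is the property the parametric form lacks.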

vchuravy (Member) commented May 5, 2020

function LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::$T, A::CuMatrix{$T}, B::CuMatrix{$T})

is equivalent to:

function LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuMatrix{T}, B::CuMatrix{T}) where T <: CublasFloat

Would be good to add a test as well.

aterenin (Contributor, Author) commented May 5, 2020

> function LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::$T, A::CuMatrix{$T}, B::CuMatrix{$T})
>
> is equivalent to:
>
> function LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuMatrix{T}, B::CuMatrix{T}) where T <: CublasFloat
>
> Would be good to add a test as well.

Unfortunately, that results in a method ambiguity error. I'll add some tests tomorrow.

vchuravy (Member) commented May 5, 2020

What are the ambiguities? There should be no difference between that and manually instantiating it.

aterenin (Contributor, Author) commented May 5, 2020

> What are the ambiguities? There should be no difference between that and manually instantiating it.

ERROR: MethodError: trsm!(::Char, ::Char, ::Char, ::Char, ::Float32, ::CuArrays.CuArray{Float32,2,Nothing}, ::CuArrays.CuArray{Float32,2,Nothing}) is ambiguous. Candidates:
  trsm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::Float32, A::AbstractArray{Float32,2}, B::AbstractArray{Float32,2}) in LinearAlgebra.BLAS at /vol/bitbucket/at6617/julia-1.4.1/share/julia/stdlib/v1.4/LinearAlgebra/src/blas.jl:1690
  trsm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuArrays.CuArray{T,2,P} where P, B::CuArrays.CuArray{T,2,P} where P) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} in SparseGaussianProcesses at /homes/at6617/SparseGaussianProcesses.jl/src/gpu.jl:50
Possible fix, define
  trsm!(::AbstractChar, ::AbstractChar, ::AbstractChar, ::AbstractChar, ::Float32, ::CuArrays.CuArray{Float32,2,P} where P, ::CuArrays.CuArray{Float32,2,P} where P)

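The ambiguity is with LinearAlgebra, which does not use a where clause and instead defines a separate method for each element type. Its Float32 method is more specific in the alpha argument, while the parametric method is more specific in the array arguments, so neither one wins.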
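
The same kind of ambiguity can be reproduced without CuArrays. The sketch below uses hypothetical stand-ins (ToyGPUMatrix and scale! are made-up names): the per-type methods play the role of the LinearAlgebra definitions, and the single parametric method plays the role of the where T <: CublasFloat version.

```julia
# Hypothetical stand-in for a GPU matrix type (not the real CuMatrix).
struct ToyGPUMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::ToyGPUMatrix) = size(A.data)
Base.getindex(A::ToyGPUMatrix, i::Int, j::Int) = A.data[i, j]

# "Library" side: concrete scalar type, abstract array type.
scale!(alpha::Float32, A::AbstractMatrix{Float32}) = :generic
scale!(alpha::Float64, A::AbstractMatrix{Float64}) = :generic

# "Package" side: parametric scalar type, specific array type.
scale!(alpha::T, A::ToyGPUMatrix{T}) where {T<:Union{Float32,Float64}} = :gpu

# For (Float32, ToyGPUMatrix{Float32}) the first method is more specific in
# alpha and the third is more specific in A, so dispatch cannot choose.
err = try
    scale!(1f0, ToyGPUMatrix(ones(Float32, 2, 2)))
catch e
    e
end

println(sprint(showerror, err))
```

Defining the method once per concrete element type removes the ambiguity, because each such signature is a subtype of the corresponding library signature.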

vchuravy (Member) commented May 5, 2020

Ah darn. I think it is probably due to the P parameter, which makes them seem conflicting, so maybe:

function LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuMatrix{T, P}, B::CuMatrix{T, P}) where {P, T <: CublasFloat}

But otherwise the @eval loop is fine.

aterenin (Contributor, Author) commented May 5, 2020

> Ah darn. I think it is probably due to the P parameter, which makes them seem conflicting, so maybe:
>
> function LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuMatrix{T, P}, B::CuMatrix{T, P}) where {P, T <: CublasFloat}
>
> But otherwise the @eval loop is fine.

Just tried

LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuMatrix{T,P}, B::CuMatrix{T,P}) where {P,T<:CublasFloat}=
    CuArrays.CUBLAS.trmm!(side, uplo, transa, diag, alpha, A, B, B)
LinearAlgebra.BLAS.trsm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuMatrix{T,P}, B::CuMatrix{T,P}) where {P,T<:CublasFloat}=
    CuArrays.CUBLAS.trsm!(side, uplo, transa, diag, alpha, A, B)

and no luck.

ERROR: MethodError: trsm!(::Char, ::Char, ::Char, ::Char, ::Float32, ::CuArrays.CuArray{Float32,2,Nothing}, ::CuArrays.CuArray{Float32,2,Nothing}) is ambiguous. Candidates:
  trsm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::Float32, A::AbstractArray{Float32,2}, B::AbstractArray{Float32,2}) in LinearAlgebra.BLAS at /vol/bitbucket/at6617/julia-1.4.1/share/julia/stdlib/v1.4/LinearAlgebra/src/blas.jl:1690
  trsm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::T, A::CuArrays.CuArray{T,2,P}, B::CuArrays.CuArray{T,2,P}) where {P, T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64}} in SparseGaussianProcesses at /homes/at6617/SparseGaussianProcesses.jl/src/gpu.jl:46
Possible fix, define
  trsm!(::AbstractChar, ::AbstractChar, ::AbstractChar, ::AbstractChar, ::Float32, ::CuArrays.CuArray{Float32,2,P}, ::CuArrays.CuArray{Float32,2,P}) where P

aterenin (Contributor, Author) commented May 6, 2020

Tests added, as well as a comment about why @eval is used, in case someone sees it later and wonders why it's there. Passes for me.

maleadt (Member) commented May 6, 2020

bors try

bors bot added a commit that referenced this pull request May 6, 2020
bors bot (Contributor) commented May 6, 2020

try

Build succeeded:

maleadt merged commit 1943f38 into JuliaGPU:master May 6, 2020