Skip to content

Commit

Permalink
Merge branch 'master' into qr_views
Browse files Browse the repository at this point in the history
  • Loading branch information
evelyne-ringoot committed Mar 10, 2023
2 parents d9ca07a + 96a4d0c commit 05ba598
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
26 changes: 13 additions & 13 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.2.1"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "0310e08cb19f5da31d08341c6120c047598f5b9c"
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.5.0"
version = "3.6.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand Down Expand Up @@ -58,9 +58,9 @@ version = "1.15.7"

[[ChangesOfVariables]]
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
git-tree-sha1 = "844b061c104c408b24537482469400af6075aae4"
git-tree-sha1 = "485193efd2176b88e6622a39a246f8c5b600e74e"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.5"
version = "0.1.6"

[[Compat]]
deps = ["Dates", "LinearAlgebra", "UUIDs"]
Expand Down Expand Up @@ -124,9 +124,9 @@ uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.8"

[[IrrationalConstants]]
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.1"
version = "0.2.2"

[[JLLWrappers]]
deps = ["Preferences"]
Expand All @@ -142,9 +142,9 @@ version = "4.16.0"

[[LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576"
git-tree-sha1 = "7718cf44439c676bc0ec66a87099f41015a522d6"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.16+0"
version = "0.0.16+2"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
Expand Down Expand Up @@ -175,9 +175,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "680e733c3a0a9cea9e935c8c2184aea6a63fa0b5"
git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.21"
version = "0.3.23"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand Down Expand Up @@ -266,9 +266,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d"
git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.1.7"
version = "2.2.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down
12 changes: 12 additions & 0 deletions lib/cublas/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,15 @@ end
error("only supports BLAS type, got $T")
end
end

op_wrappers = ((identity, T -> 'N', identity),
(T -> :(Transpose{T, <:$T}), T -> 'T', A -> :(parent($A))),
(T -> :(Adjoint{T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))

for op in (:(+), :(-))
for (wrapa, transa, unwrapa) in op_wrappers, (wrapb, transb, unwrapb) in op_wrappers
TypeA = wrapa(:(CuMatrix{T}))
TypeB = wrapb(:(CuMatrix{T}))
@eval Base.$op(A::$TypeA, B::$TypeB) where {T <: CublasFloat} = CUBLAS.geam($transa(T), $transb(T), one(T), $(unwrapa(:A)), $(op)(one(T)), $(unwrapb(:B)))
end
end
21 changes: 21 additions & 0 deletions test/cublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,27 @@ end
h_C = Array(d_C)
@test D h_C
end
@testset "CuMatrix -- A ± B -- $elty" begin
for opa in (identity, transpose, adjoint)
for opb in (identity, transpose, adjoint)
n = 10
m = 20
geam_A = opa == identity ? rand(elty, n, m) : rand(elty, m, n)
geam_B = opb == identity ? rand(elty, n, m) : rand(elty, m, n)

geam_dA = CuMatrix{elty}(geam_A)
geam_dB = CuMatrix{elty}(geam_B)

geam_C = opa(geam_A) + opb(geam_B)
geam_dC = opa(geam_dA) + opb(geam_dB)
@test geam_C collect(geam_dC)

geam_C = opa(geam_A) - opb(geam_B)
geam_dC = opa(geam_dA) - opb(geam_dB)
@test geam_C collect(geam_dC)
end
end
end
A = rand(elty,m,k)
d_A = CuArray(A)
@testset "syrkx!" begin
Expand Down

0 comments on commit 05ba598

Please sign in to comment.