-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove CUDA dependence in favor of extension (#318)
* cuda extension * fix
- Loading branch information
1 parent
f59ce44
commit 92d3163
Showing
14 changed files
with
84 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
|
||
GNNGraphs.iscuarray(x::AnyCuArray) = true | ||
|
||
|
||
function sort_edge_index(u::AnyCuArray, v::AnyCuArray) | ||
#TODO proper cuda friendly implementation | ||
sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu | ||
end |
17 changes: 17 additions & 0 deletions
17
ext/GraphNeuralNetworksCUDAExt/GraphNeuralNetworksCUDAExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
module GraphNeuralNetworksCUDAExt | ||
|
||
using CUDA | ||
using Random, Statistics, LinearAlgebra | ||
using GraphNeuralNetworks | ||
using GraphNeuralNetworks.GNNGraphs | ||
using GraphNeuralNetworks.GNNGraphs: COO_T, ADJMAT_T, SPARSE_T | ||
import GraphNeuralNetworks: propagate | ||
|
||
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} | ||
|
||
include("GNNGraphs/query.jl") | ||
include("GNNGraphs/transform.jl") | ||
include("GNNGraphs/utils.jl") | ||
include("msgpass.jl") | ||
|
||
end #module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
|
||
###### PROPAGATE SPECIALIZATIONS #################### | ||
|
||
## COPY_XJ | ||
|
||
## avoid the fast path on gpu until we have better cuda support | ||
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), | ||
xi, xj::AnyCuMatrix, e) | ||
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e) | ||
end | ||
|
||
## E_MUL_XJ | ||
|
||
## avoid the fast path on gpu until we have better cuda support | ||
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), | ||
xi, xj::AnyCuMatrix, e::AbstractVector) | ||
propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e) | ||
end | ||
|
||
## W_MUL_XJ | ||
|
||
## avoid the fast path on gpu until we have better cuda support | ||
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), | ||
xi, xj::AnyCuMatrix, e::Nothing) | ||
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e) | ||
end | ||
|
||
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) | ||
# A = adjacency_matrix(g, weighted=false) | ||
# D = compute_degree(A) | ||
# return xj * A * D | ||
# end | ||
|
||
# # Zygote bug. Error with sparse matrix without nograd | ||
# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) | ||
|
||
# Flux.Zygote.@nograd compute_degree |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters