diff --git a/Project.toml b/Project.toml index b23a29603..4785fe5dc 100644 --- a/Project.toml +++ b/Project.toml @@ -24,9 +24,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" [extensions] GraphNeuralNetworksCUDAExt = "CUDA" +GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" [compat] Adapt = "3, 4" @@ -45,6 +47,7 @@ NNlib = "0.9" NearestNeighbors = "0.4" Random = "1" Reexport = "1" +SimpleWeightedGraphs = "1.4.0" SparseArrays = "1" Statistics = "1" StatsBase = "0.34" @@ -59,9 +62,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Test", "Adapt", "DataFrames", "InlineStrings", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"] +test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"] diff --git a/ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt.jl b/ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt.jl new file mode 100644 index 000000000..aabc13443 --- /dev/null +++ b/ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt.jl @@ -0,0 +1,12 @@ +module GraphNeuralNetworksSimpleWeightedGraphsExt + +using GraphNeuralNetworks +using Graphs +using SimpleWeightedGraphs + +function GraphNeuralNetworks.GNNGraph(g::T; kws...) where + {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}} + return GNNGraph(g.weights, kws...) +end + +end #module \ No newline at end of file diff --git a/test/GNNGraphs/generate.jl b/test/GNNGraphs/generate.jl index 675e3539e..d9f281fb2 100644 --- a/test/GNNGraphs/generate.jl +++ b/test/GNNGraphs/generate.jl @@ -119,4 +119,4 @@ end R = 10 tg1 = rand_temporal_hyperbolic_graph(number_nodes, number_snapshots; α, R, speed, ζ) @test mean(mean(degree.(tg1.snapshots)))<=mean(mean(degree.(tg.snapshots))) -end \ No newline at end of file +end diff --git a/test/ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt.jl b/test/ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt.jl new file mode 100644 index 000000000..254498999 --- /dev/null +++ b/test/ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt.jl @@ -0,0 +1,11 @@ +@testset "simple_weighted_graph" begin + srcs = [1, 2, 1] + dsts = [2, 3, 3] + wts = [0.5, 0.8, 2.0] + g = SimpleWeightedGraph(srcs, dsts, wts) + gd = SimpleWeightedDiGraph(srcs, dsts, wts) + gnn_g = GNNGraph(g) + gnn_gd = GNNGraph(gd) + @test get_edge_weight(gnn_g) == [0.5, 2, 0.5, 0.8, 2.0, 0.8] + @test get_edge_weight(gnn_gd) == [0.5, 2, 0.8] +end diff --git a/test/runtests.jl b/test/runtests.jl index 271373ecc..85e26ac38 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ using Zygote using Test using MLDatasets using InlineStrings # not used but with the import we test #98 and #104 +using SimpleWeightedGraphs CUDA.allowscalar(false) @@ -46,6 +47,7 @@ tests = [ "mldatasets", "examples/node_classification_cora", "deprecations", + "ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt" ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")