Skip to content

Commit

Permalink
feat: support weights when generating from SimpleWeightedGraph (#371)
Browse files Browse the repository at this point in the history
* add simpleweightedgraph support

* add test

* add SimpleWeightedGraphs to runtests.jl

* replace import for using as it's not python

* change to extension

* remove SimpleWeightedGraphs from deps

* add PR review suggestions

* add test

* refinement

* add PR review suggestions
  • Loading branch information
askorupka authored Mar 7, 2024
1 parent 19af4ec commit 91ddd53
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 2 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"

[extensions]
GraphNeuralNetworksCUDAExt = "CUDA"
GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs"

[compat]
Adapt = "3, 4"
Expand All @@ -45,6 +47,7 @@ NNlib = "0.9"
NearestNeighbors = "0.4"
Random = "1"
Reexport = "1"
SimpleWeightedGraphs = "1.4.0"
SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
Expand All @@ -59,9 +62,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module GraphNeuralNetworksSimpleWeightedGraphsExt

using GraphNeuralNetworks
using Graphs
using SimpleWeightedGraphs

function GraphNeuralNetworks.GNNGraph(g::T; kws...) where
{T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
return GNNGraph(g.weights, kws...)
end

end #module
2 changes: 1 addition & 1 deletion test/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,4 @@ end
R = 10
tg1 = rand_temporal_hyperbolic_graph(number_nodes, number_snapshots; α, R, speed, ζ)
@test mean(mean(degree.(tg1.snapshots)))<=mean(mean(degree.(tg.snapshots)))
end
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@testset "simple_weighted_graph" begin
srcs = [1, 2, 1]
dsts = [2, 3, 3]
wts = [0.5, 0.8, 2.0]
g = SimpleWeightedGraph(srcs, dsts, wts)
gd = SimpleWeightedDiGraph(srcs, dsts, wts)
gnn_g = GNNGraph(g)
gnn_gd = GNNGraph(gd)
@test get_edge_weight(gnn_g) == [0.5, 2, 0.5, 0.8, 2.0, 0.8]
@test get_edge_weight(gnn_gd) == [0.5, 2, 0.8]
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using Zygote
using Test
using MLDatasets
using InlineStrings # not used but with the import we test #98 and #104
using SimpleWeightedGraphs

CUDA.allowscalar(false)

Expand Down Expand Up @@ -46,6 +47,7 @@ tests = [
"mldatasets",
"examples/node_classification_cora",
"deprecations",
"ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt"
]

!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
Expand Down

0 comments on commit 91ddd53

Please sign in to comment.