Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add drop_nodes transform #426

Merged
merged 11 commits into from
Jun 27, 2024
1 change: 1 addition & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export add_nodes,
to_unidirected,
random_walk_pe,
remove_nodes,
drop_nodes,
# from Flux
batch,
unbatch,
Expand Down
32 changes: 32 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,38 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
ndata, edata, g.gdata)
end

"""
drop_nodes(g::GNNGraph{<:COO_T}, p::Float32)
rbSparky marked this conversation as resolved.
Show resolved Hide resolved

Randomly drop nodes (and their associated edges) from a GNNGraph based on a given probability.
Dropping nodes is a technique that can be used for graph data augmentation, refering paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf).

# Arguments
- `g`: The input graph from which nodes (and their associated edges) will be dropped.
- `p`: The probability of dropping each node. Default value is `0.5`.

# Returns
A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.

# Example
```julia
using GraphNeuralNetworks
# Construct a GNNGraph
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3)
# Drop nodes with a probability of 0.5
g_new = drop_node(g, 0.5)
println(g_new)
```
"""
function drop_nodes(g::GNNGraph{<:COO_T}, p::Float32 = 0.5f)
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
num_nodes = g.num_nodes
nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)

new_g = remove_nodes(g, nodes_to_remove)

return new_g
end

"""
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
add_edges(g::GNNGraph, (s, t); [edata])
Expand Down
18 changes: 18 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,24 @@ end
@test edata_new == edatatest
end end

@testset "drop_nodes" begin
if GRAPH_T == :coo
Random.seed!(42)
s = [1, 1, 2, 3]
t = [2, 3, 4, 5]
g = GNNGraph(s, t, graph_type = GRAPH_T)

gnew = drop_nodes(g, Float32(0.5))
@test gnew.num_nodes == 3

gnew = drop_nodes(g, Float32(1.0))
@test gnew.num_nodes == 0

gnew = drop_nodes(g, Float32(0.0))
@test gnew.num_nodes == 5
end
end

@testset "add_nodes" begin if GRAPH_T == :coo
g = rand_graph(6, 4, ndata = rand(2, 6), graph_type = GRAPH_T)
gnew = add_nodes(g, 5, ndata = ones(2, 5))
Expand Down
Loading