Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add NeighborLoader #497

Merged
merged 33 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
349d99f
feat: init neighbor loader
askorupka Sep 11, 2024
07db4d0
feat: init neighbor loader
askorupka Sep 11, 2024
25945c7
feat: refine neighborloader
askorupka Sep 11, 2024
506d4c7
fix: refine neighborloader
askorupka Sep 11, 2024
991bf61
fix: refine neighborloader
askorupka Sep 11, 2024
c25bc1e
fix: refine neighborloader
askorupka Sep 11, 2024
fde10bb
fix: refine neighborloader
askorupka Sep 11, 2024
3656691
fix: refine neighborloader
askorupka Sep 11, 2024
9997fab
chore: add some comments
askorupka Sep 12, 2024
acf209c
chore: add TODO comments
askorupka Sep 17, 2024
0c4a653
feat: add tests, refine code
askorupka Sep 28, 2024
2035c5e
fix: add samplers.jl after rebase
askorupka Sep 28, 2024
ebebce9
chore: add docstrings
askorupka Sep 29, 2024
abf31cd
chore: Graphs to deps
askorupka Sep 29, 2024
bcdfa5e
chore: move using Graphs to main file
askorupka Sep 29, 2024
970d297
chore: readd Graphs to extras
askorupka Sep 29, 2024
b4c1ad7
chore: delete src/samplers.jl created by mistake
askorupka Sep 29, 2024
5e7544c
fix: add sampling.jl to docs
askorupka Oct 12, 2024
c9d412b
fix: add sampling.jl to docs
askorupka Oct 12, 2024
2d7bd0b
fix: add sampling.jl to docs
askorupka Oct 12, 2024
65aa564
fix: deduplicate function
askorupka Oct 12, 2024
61c5e39
fix: fix broken tests
askorupka Oct 24, 2024
aec5574
chore: remove printlns
askorupka Oct 24, 2024
e675086
Update GraphNeuralNetworks/src/samplers.jl
askorupka Oct 27, 2024
62f5d87
fix: remove docstrings where not needed
askorupka Oct 27, 2024
3ed22bf
chore: add ref to the paper
askorupka Oct 27, 2024
e4dc977
Update GraphNeuralNetworks/src/samplers.jl
askorupka Oct 29, 2024
962a97f
Update GraphNeuralNetworks/src/GraphNeuralNetworks.jl
askorupka Nov 2, 2024
d552de4
Update GraphNeuralNetworks/src/samplers.jl
askorupka Nov 2, 2024
6f26713
chore: add compat for Graphs
askorupka Nov 2, 2024
a4b6e15
refactor: allow input_nodes to be nothing
askorupka Nov 2, 2024
9af384f
chore: add loader iterate example to docstring
askorupka Nov 2, 2024
aa18520
fix: fix tests (docstring error)
askorupka Nov 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})

node_map = Dict(node => i for (i, node) in enumerate(nodes))

edge_list = [collect(t) for t in zip(edge_index(graph)[1],edge_index(graph)[2])]

# Collect edges to add
source = Int[]
target = Int[]
Expand All @@ -187,8 +189,7 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
if neighbor in keys(node_map)
push!(target, node_map[node])
push!(source, node_map[neighbor])

eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
eindex = findfirst(x -> x == [neighbor, node], edge_list)
push!(eindices, eindex)
end
end
Expand Down
1 change: 1 addition & 0 deletions GraphNeuralNetworks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
askorupka marked this conversation as resolved.
Show resolved Hide resolved
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand Down
1 change: 1 addition & 0 deletions GraphNeuralNetworks/docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ makedocs(;
"Message Passing" => "api/messagepassing.md",
"Heterogeneous Graphs" => "api/heterograph.md",
"Temporal Graphs" => "api/temporalgraph.md",
"Samplers" => "api/samplers.md",
"Utils" => "api/utils.md",
],
"Developer Notes" => "dev.md",
Expand Down
14 changes: 14 additions & 0 deletions GraphNeuralNetworks/docs/src/api/samplers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
```@meta
CurrentModule = GraphNeuralNetworks
```

# Samplers


## Docs

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["samplers.jl"]
Private = false
```
4 changes: 4 additions & 0 deletions GraphNeuralNetworks/src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using NNlib: scatter, gather
using ChainRulesCore
using Reexport
using MLUtils: zeros_like
using Graphs
askorupka marked this conversation as resolved.
Show resolved Hide resolved

using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
Expand Down Expand Up @@ -66,4 +67,7 @@ export GlobalPool,

include("deprecations.jl")

include("samplers.jl")
export NeighborLoader

end
110 changes: 110 additions & 0 deletions GraphNeuralNetworks/src/samplers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
struct NeighborLoader
askorupka marked this conversation as resolved.
Show resolved Hide resolved

A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs).
It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training
originally introduced in "Inductive Representation Learning on Large Graphs" paper.
[see https://arxiv.org/abs/1706.02216]

# Fields
- `graph::GNNGraph`: The input graph.
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer.
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling.
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors).
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`.

askorupka marked this conversation as resolved.
Show resolved Hide resolved
# Usage
```julia
loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2)
```
askorupka marked this conversation as resolved.
Show resolved Hide resolved
"""
struct NeighborLoader
graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl)
num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer
input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling)
num_layers::Int # Number of layers for neighborhood expansion
batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given
neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation
end

function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
askorupka marked this conversation as resolved.
Show resolved Hide resolved
return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}())
askorupka marked this conversation as resolved.
Show resolved Hide resolved
end

# Function to get cached neighbors or compute them
function get_neighbors(loader::NeighborLoader, node::Int)
if haskey(loader.neighbors_cache, node)
return loader.neighbors_cache[node]
else
neighbors = Graphs.neighbors(loader.graph, node, dir = :in) # Get neighbors from graph
loader.neighbors_cache[node] = neighbors
return neighbors
end
end

"""
sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)

Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN.
The number of neighbors sampled is defined in `loader.num_neighbors`.

# Arguments:
- `loader::NeighborLoader`: The `NeighborLoader` instance.
- `node::Int`: The node to sample neighbors for.
- `layer::Int`: The current GNN layer (used to determine how many neighbors to sample).

# Returns:
A vector of sampled neighbor node indices.
"""
askorupka marked this conversation as resolved.
Show resolved Hide resolved
# Function to sample neighbors for a given node at a specific layer
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
neighbors = get_neighbors(loader, node)
if isempty(neighbors)
return Int[]
else
num_samples = min(loader.num_neighbors[layer], length(neighbors)) # Limit to required samples for this layer
return rand(neighbors, num_samples) # Randomly sample neighbors
end
end

# Iterator protocol for NeighborLoader with lazy batch loading
function Base.iterate(loader::NeighborLoader, state=1)
if state > length(loader.input_nodes)
return nothing # End of iteration if batches are exhausted (state larger than amount of input nodes or current batch no >= batch number)
end

# Determine the size of the current batch
batch_size = min(loader.batch_size, length(loader.input_nodes) - state + 1) # Conditional in case there is not enough nodes to fill the last batch
batch_nodes = loader.input_nodes[state:state + batch_size - 1] # Each mini-batch uses different set of input nodes

# Set for tracking the subgraph nodes
subgraph_nodes = Set(batch_nodes)

for node in batch_nodes
# Initialize current layer of nodes (starting with the node itself)
sampled_neighbors = Set([node])

# For each GNN layer, sample the neighborhood
for layer in 1:loader.num_layers
new_neighbors = Set{Int}()
for n in sampled_neighbors
neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer
new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set
end
sampled_neighbors = new_neighbors
subgraph_nodes = union(subgraph_nodes, sampled_neighbors) # Expand the subgraph with the new neighbors
end
end

# Collect subgraph nodes and their features
subgraph_node_list = collect(subgraph_nodes)

if isempty(subgraph_node_list)
return GNNGraph(), state + batch_size
end

mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list) # Create a subgraph of the nodes

# Continue iteration for the next batch
return mini_batch_gnn, state + batch_size
end
1 change: 1 addition & 0 deletions GraphNeuralNetworks/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ tests = [
"layers/temporalconv",
"layers/pool",
"examples/node_classification_cora",
"samplers"
]

!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
Expand Down
125 changes: 125 additions & 0 deletions GraphNeuralNetworks/test/samplers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Helper function to create a simple graph with node features using GNNGraph
function create_test_graph()
source = [1, 2, 3, 4] # Define source nodes of edges
target = [2, 3, 4, 5] # Define target nodes of edges
node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes)

return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features
end

# Tests for NeighborLoader structure and its functionalities
@testset "NeighborLoader tests" begin

# 1. Basic functionality: Check neighbor sampling and subgraph creation
@testset "Basic functionality" begin
g = create_test_graph()

# Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2
loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph is not empty
@test !isempty(mini_batch_gnn.graph)

num_sampled_nodes = mini_batch_gnn.num_nodes
println("Number of nodes in mini-batch: ", num_sampled_nodes)

@test num_sampled_nodes == 2

# Test if there are edges in the subgraph
@test mini_batch_gnn.num_edges > 0
end

# 2. Edge case: Single node with no neighbors
@testset "Single node with no neighbors" begin
g = SimpleDiGraph(1) # A graph with a single node and no edges
node_features = rand(Float32, 5, 1)
graph = GNNGraph(g, ndata = node_features)

loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph contains only one node
@test size(mini_batch_gnn.x, 2) == 1
end

# 3. Edge case: A node with no outgoing edges (isolated node)
@testset "Node with no outgoing edges" begin
g = SimpleDiGraph(2) # Graph with 2 nodes, no edges
node_features = rand(Float32, 5, 2)
graph = GNNGraph(g, ndata = node_features)

loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled)
@test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes
end

# 4. Edge case: A fully connected graph
@testset "Fully connected graph" begin
g = SimpleDiGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 1)
node_features = rand(Float32, 5, 3)
graph = GNNGraph(g, ndata = node_features)

loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2)

mini_batch_gnn, next_state = iterate(loader)

# Test if all nodes are included in the mini-batch since it's fully connected
@test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included
end

# 5. Edge case: More layers than the number of neighbors
@testset "More layers than available neighbors" begin
g = SimpleDiGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
node_features = rand(Float32, 5, 3)
graph = GNNGraph(g, ndata = node_features)

# Test with 3 layers but only enough connections for 2 layers
loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph contains all available nodes
@test size(mini_batch_gnn.x, 2) == 1
end

# 6. Edge case: Large batch size greater than the number of input nodes
@testset "Large batch size" begin
g = create_test_graph()

# Define NeighborLoader with a larger batch size than input nodes
loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph is not empty
@test !isempty(mini_batch_gnn.graph)

# Test if the correct number of nodes are sampled
@test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected
end

# 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer
@testset "No neighbors sampled" begin
g = create_test_graph()

# Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2
loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph contains only the input nodes
@test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph
end

end
Loading