Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add induced_subgraph functionality #499

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ export rand_graph,

include("sampling.jl")
export sample_neighbors
export induced_subgraph

include("operators.jl")
# Base.intersect
Expand Down
50 changes: 50 additions & 0 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,53 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1;
end
return gnew
end

"""
induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph

Generates a subgraph from the original graph using the provided `nodes`.
The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph.
If a node has no neighbors, an isolated node will be added to the subgraph.

# Arguments:
- `graph::GNNGraph`: The original graph containing nodes, edges, and node features.
- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph.

# Returns:
A new `GNNGraph` containing the subgraph with the specified nodes and their features.
"""
function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
if isempty(nodes)
return GNNGraph() # Return empty graph if no nodes are provided
end

node_map = Dict(node => i for (i, node) in enumerate(nodes))

# Collect edges to add
source = Int[]
target = Int[]
backup_gnn = GNNGraph()
for node in nodes
neighbors = Graphs.neighbors(graph, node, dir = :in)
if isempty(neighbors)
backup_gnn = add_nodes(backup_gnn, 1)
end
for neighbor in neighbors
if neighbor in keys(node_map)
push!(source, node_map[node])
push!(target, node_map[neighbor])
end
end
end

# Extract features for the new nodes
#new_features = graph.x[:, nodes]

if isempty(source) && isempty(target)
#backup_gnn.ndata.x = new_features ### TODO fix & add edges data (probably push themto the new vector?)
return backup_gnn # Return empty graph if no nodes are provided
end

return GNNGraph(source, target)
#, ndata = new_features) # Return the new GNNGraph with subgraph and features
end
16 changes: 16 additions & 0 deletions GNNGraphs/test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,20 @@ if GRAPH_T == :coo
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
end

@testset "induced_subgraph" begin
# Create a simple GNNGraph with two nodes and one edge
s = [1]
t = [2]
### TODO add data
graph = GNNGraph((s, t))

# Induce subgraph on both nodes
nodes = [1, 2]
subgraph = induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 2 # Subgraph should have 2 nodes
@test subgraph.num_edges == 1 # Subgraph should have 1 edge
### TODO @test subgraph.ndata.x == graph.x[:, nodes] # Features should match the original graph
end
end
Loading