Skip to content

Commit

Permalink
Add TemporalSnapshotsGNNGraph struct (#293)
Browse files Browse the repository at this point in the history
* Add `TemporalSnapshotsGNNGraph` struct

* Remove typo

* Add `==` function

* Add `add/remove_snaposhot` and `show` functions

* Add test

* Export `TemporalSnapshotsGNNgraph` functions

* Add temporalsnapshotsgnngraph tests

* Rename file and function

* Fix comma

* Add test `show`
  • Loading branch information
aurorarossi authored Jun 9, 2023
1 parent d394b91 commit 65c0faa
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ export GNNGraph,
include("gnnheterograph.jl")
export GNNHeteroGraph


include("temporalsnapshotsgnngraph.jl")
export TemporalSnapshotsGNNGraph,
add_snapshot,
remove_snapshot

include("query.jl")
export adjacency_list,
edge_index,
Expand Down
103 changes: 103 additions & 0 deletions src/GNNGraphs/temporalsnapshotsgnngraph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
struct TemporalSnapshotsGNNGraph
num_nodes::Vector{Int}
num_edges::Vector{Int}
num_snapshots::Int
snapshots::Vector{<:GNNGraph}
tgdata::DataStore
end

function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
@assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes"
return TemporalSnapshotsGNNGraph(
[s.num_nodes for s in snapshots],
[s.num_edges for s in snapshots],
length(snapshots),
snapshots,
DataStore()
)
end

function Base.:(==)(tsg1::TemporalSnapshotsGNNGraph, tsg2::TemporalSnapshotsGNNGraph)
tsg1 === tsg2 && return true
for k in fieldnames(typeof(tsg1))
getfield(tsg1, k) != getfield(tsg2, k) && return false
end
return true
end

function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int)
return tg.snapshots[t]
end

function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector)
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata)
end

function add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph)
@assert g.num_nodes == tg.num_nodes[t] "number of nodes must match"
num_nodes= tg.num_nodes
num_edges = tg.num_edges
snapshots = tg.snapshots
num_snapshots = tg.num_snapshots + 1
insert!(num_nodes, t, g.num_nodes)
insert!(num_edges, t, g.num_edges)
insert!(snapshots, t, g)
return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata)
end

function remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int)
num_nodes= tg.num_nodes
num_edges = tg.num_edges
snapshots = tg.snapshots
num_snapshots = tg.num_snapshots - 1
deleteat!(num_nodes, t)
deleteat!(num_edges, t)
deleteat!(snapshots, t)
return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata)
end

function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph)
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
print_feature_t(io, tsg.tgdata)
print(io, " data")
end

function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph)
if get(io, :compact, false)
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
print_feature_t(io, tsg.tgdata)
print(io, " data")
else
print(io,
"TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)")
if !isempty(tsg.tgdata)
print(io, "\n tgdata:")
for k in keys(tsg.tgdata)
print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))")
end
end
end
end


function print_feature_t(io::IO, feature)
if !isempty(feature)
if length(keys(feature)) == 1
k = first(keys(feature))
v = first(values(feature))
print(io, "$(k): $(dims2string(size(v)))")
else
print(io, "(")
for (i, (k, v)) in enumerate(pairs(feature))
print(io, "$k: $(dims2string(size(v)))")
if i == length(feature)
print(io, ")")
else
print(io, ", ")
end
end
end
else
print(io, "no")
end
end
51 changes: 51 additions & 0 deletions test/GNNGraphs/temporalsnapshotsgnngraph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
@testset "Constructor array TemporalSnapshotsGNNGraph" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
@test tsg.num_nodes == [10 for i in 1:5]
@test tsg.num_edges == [20 for i in 1:5]
wrsnapshots = [rand_graph(10,20), rand_graph(12,22)]
@test_throws AssertionError TemporalSnapshotsGNNGraph(wrsnapshots)
end

@testset "==" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg1 = TemporalSnapshotsGNNGraph(snapshots)
tsg2 = TemporalSnapshotsGNNGraph(snapshots)
@test tsg1 == tsg2
tsg3 = TemporalSnapshotsGNNGraph(snapshots[1:3])
@test tsg1 != tsg3
@test tsg1 !== tsg3
end

@testset "getindex" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
@test tsg[3] == snapshots[3]
@test tsg[[1,2]] == TemporalSnapshotsGNNGraph([10,10], [20,20], 2, snapshots[1:2], tsg.tgdata)
end

@testset "add/remove_snapshot" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
g = rand_graph(10, 20)
tsg = add_snapshot(tsg, 3, g)
@test tsg.num_nodes == [10 for i in 1:6]
@test tsg.num_edges == [20 for i in 1:6]
@test tsg.snapshots[3] == g
tsg = remove_snapshot(tsg, 3)
@test tsg.num_nodes == [10 for i in 1:5]
@test tsg.num_edges == [20 for i in 1:5]
@test tsg.snapshots == snapshots
end

@testset "show" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data"
@test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data"
@test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5"
tsg.tgdata.x=rand(4)
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data"
end

# @test sprint(show, MIME("text/plain"), rand_graph(10, 20); context=:compact => true) == "GNNGraph(10, 20) with no data"
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ tests = [
"GNNGraphs/query",
"GNNGraphs/sampling",
"GNNGraphs/gnnheterograph",
"GNNGraphs/temporalsnapshotsgnngraph",
"utils",
"msgpass",
"layers/basic",
Expand Down

0 comments on commit 65c0faa

Please sign in to comment.