diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 11c3e4039..a680f77aa 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -61,6 +61,7 @@ export ResGatedGraphConv, SAGEConv, GMMConv, + EdgeWeightNorm, # layers/pool GlobalPool, diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7bfbe9592..98ceb444d 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1181,3 +1181,69 @@ function Base.show(io::IO, l::GMMConv) l.residual==true || print(io, ", residual=", l.residual) print(io, ")") end + +@doc raw""" + EdgeWeightNorm(norm_both = true, eps = 0) + +Normalizes positive scalar edge weights on a graph following the form in GCN. + +norm_both = `true` yields the following normalization term: +```math +c_{ji} = (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}}) +``` +norm_both = `false` yields the following normalization term: +```math +c_{ji} = (\sum_{k\in\mathcal{N}(i)}e_{ki}) +``` +where ``e_{ji}`` is the scalar weight on the edge from node j to node i. + +Return value is the normalized weight ``e_{ji} / c_{ji}`` for all edges in vector form. + +# Arguments + +- `norm_both`: The normalizer as specified above. Default is `true`. +- `eps`: Offset value in the denominator. Default is `0`. + +# Examples + +```julia +# create data +g = GNNGraph([1,2,3,4,3,6], [2,3,4,5,1,4]) +g = add_self_loops(g) + +# edge weights +edge_weights = [0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1] + +l = EdgeWeightNorm() +l(g, edge_weights) +``` +""" +struct EdgeWeightNorm <: GNNLayer + norm_both::Bool + eps::Float64 +end + +@functor EdgeWeightNorm + +function EdgeWeightNorm(norm_both::Bool = true, + eps::Float64 = 0) + EdgeWeightNorm(norm_both, eps) +end + +function (l::EdgeWeightNorm)(g::GNNGraph, edge_weight::T) where T <: AbstractVector + norm_val = T() + edge_in, edge_out = edge_index(g) + + dg_in = degree(g; dir = :in, edge_weight) + dg_out = degree(g; dir = :out, edge_weight) + + for iter in 1:length(edge_weight) + if l.norm_both + push!(norm_val, edge_weight[iter] / (sqrt(dg_out[in[iter]] * dg_in[out[iter]]) + l.eps)) + else + push!(norm_val, edge_weight[iter] / (dg_in[out[iter]] + l.eps)) + end + end + + return norm_val +end