I’m trying to model Graph NODEs by combining GraphNeuralNetworks.jl and OrdinaryDiffEq.jl. I want to learn both the neural network parameters and the edge weights, so I have to manually modify the Flux parameters during prediction. When I run the following MWE, I get an error:
using Graphs, GraphNeuralNetworks, Flux, OrdinaryDiffEq, ComponentArrays, Zygote, SciMLSensitivity

time = 1:10
x0 = rand(9)
obs = rand(9, 10)
fullGraph = GNNGraph(complete_digraph(3))
layer1 = GCNConv(3 => 10, tanh, use_edge_weight=true)
layer2 = GCNConv(10 => 3, use_edge_weight=true)
chain = GNNChain(layer1, layer2)
pinit = ComponentArray{Float32}(weights = rand(ne(fullGraph)),
    layer1 = f64(layer1.weight), layer2 = f64(layer2.weight))

function predict(p)
    fullGraph = GNNGraph(complete_digraph(3))
    fullGraph = set_edge_weight(fullGraph, p.weights)
    chain.layers[1].weight .= p.layer1
    chain.layers[2].weight .= p.layer2
    function nn!(du, u, p, t)
        uGraph = reshape(u, (3, 3))
        dGraph = reshape(chain(fullGraph, uGraph), (3 * 3))
        du .= dGraph
    end
    prob = ODEProblem(nn!, x0, (time[1], time[end]), saveat=time)
    sol = solve(prob)
    return Array(sol)
end

function loss_function(p)
    pred = predict(p)
    sum(abs2, pred .- obs)
end

Zygote.gradient(loss_function, pinit)
I'm crossposting this from the discourse as I don't know if this is necessarily a bug with GraphNeuralNetworks.jl or if the devs know a better alternative to do these kinds of processes.
Thanks!