[BUG] torch_geometric modules not working in objectives #1576
Comments
I'd like to train a GNN for multi-agent cooperation in TorchRL, but this is a blocker for me. It would be great if torch_geometric worked for this use case!
Nit: ...
loss_1 = DQNLoss(value_network=value, action_space="one_hot")
... and the data cannot be passed to the loss (it lacks actions, done states, etc.), so it's a bit hard to test anything with it. If fixed, I believe this should work under #1711.
Right, so I can confirm that following #1711, #1707 and pytorch/tensordict#579 the following code runs:
import torch
import torch_geometric
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss
from tensordict import TensorDict
model = torch_geometric.nn.Linear(2, 3)
value = QValueActor(module=model, in_keys="obs", action_space="one_hot")
loss_1 = DQNLoss(value_network=value, action_space="one_hot")
# Dummy batch of 3 transitions with the keys DQNLoss expects
data = TensorDict(
    {
        "obs": torch.zeros(3, 2),
        "action": torch.zeros(3, 3),
        "next": {
            "obs": torch.zeros(3, 2),
            "reward": torch.zeros(3, 1),
            "done": torch.zeros(3, 1).bool(),
            "terminated": torch.zeros(3, 1).bool(),
            "truncated": torch.zeros(3, 1).bool(),
        },
    },
    batch_size=[3],
)
loss_1(data)
If you try using modules from torch_geometric.nn in a TorchRL loss, an exception will be raised.
pip install torch_geometric
I can't tell why this happens, since the model in this example is extremely simple.
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/dense/linear.html#Linear
I think this is quite important, as torch_geometric is the leading library for GNNs in PyTorch.