
[BUG] torch_geometric modules not working in objectives #1576

Closed
matteobettini opened this issue Sep 26, 2023 · 3 comments
Labels
bug Something isn't working
matteobettini commented Sep 26, 2023

If you try using modules from torch_geometric.nn in a TorchRL loss, an exception is raised:

pip install torch_geometric

import torch
import torch_geometric
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss
from tensordict import TensorDict

model = torch_geometric.nn.Linear(2, 3)
value = QValueActor(module=model, in_keys="obs", action_space="one_hot")
loss_1 = DQNLoss(value_network=model, action_space="one_hot")

data = TensorDict({"obs": torch.zeros(3, 2)}, batch_size=[])
loss_1(data)
Traceback (most recent call last):
  File "/Users/matbet/PycharmProjects/rl/prova.py", line 15, in <module>
    loss_1(data)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/common.py", line 282, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/Users/matbet/PycharmProjects/rl/torchrl/objectives/dqn.py", line 282, in forward
    self.value_network(
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/functional_modules.py", line 572, in new_fun
    old_params = _assign_params(
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/functional_modules.py", line 649, in _assign_params
    return _swap_state(module, params, make_stateless, return_old_tensordict)
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/functional_modules.py", line 389, in _swap_state
    _old_value = _swap_state(
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/functional_modules.py", line 389, in _swap_state
    _old_value = _swap_state(
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/functional_modules.py", line 378, in _swap_state
    raise Exception(f"{model}\nhas no stateless attribute.")
Exception: Linear(2, 3, bias=False)
has no stateless attribute.

I cannot understand why this is happening, since the model in this example is extremely simple:
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/dense/linear.html#Linear
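For context, the traceback points at tensordict's functional-call machinery, which temporarily swaps a module's parameters for ones stored in a TensorDict by recursing over attributes. The sketch below is a simplified, hypothetical illustration of that swap-and-restore pattern (the names `Leaf` and `swap_state` are made up for this example, not tensordict's actual code); it shows how the recursion raises when it expects an attribute the module does not expose, which is the shape of the error above.

```python
class Leaf:
    """Stands in for a plain submodule holding one parameter-like attribute."""

    def __init__(self, weight):
        self.weight = weight


def swap_state(module, params):
    """Recursively replace `module` attributes with entries from `params`,
    returning the old values so they can be restored after the call."""
    old = {}
    for name, value in params.items():
        if isinstance(value, dict):
            child = getattr(module, name, None)
            if child is None:
                # Mirrors the failure mode in the traceback: the recursion
                # expects a submodule attribute that the module lacks.
                raise Exception(f"{module}\nhas no {name} attribute.")
            old[name] = swap_state(child, value)
        else:
            old[name] = getattr(module, name)
            setattr(module, name, value)
    return old


root = Leaf(weight=1.0)
old = swap_state(root, {"weight": 2.0})   # swap in the "functional" params
print(root.weight, old)                   # → 2.0 {'weight': 1.0}
swap_state(root, old)                     # restore the original state
```

If torch_geometric's modules store their parameters in a way this attribute walk does not anticipate, the swap fails even for a trivially small model.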

I think this is quite important, as torch_geometric is the leading library for GNNs in PyTorch.

@janblumenkamp

I'd like to train a GNN for multi-agent cooperation in TorchRL, but this is a blocker for me. It would be great if torch_geometric would work for this use case!


vmoens commented Nov 26, 2023

Nit:
I think this code should read

...
loss_1 = DQNLoss(value_network=value, action_space="one_hot")
...

and the data cannot be passed to the loss as written (it lacks actions, done states, etc.), so it's a bit hard to test anything with it.

Once that is fixed, I believe this should work under #1711.

vmoens commented Nov 27, 2023

Right, so I can confirm that, following #1711, #1707 and pytorch/tensordict#579, the following code runs:

import torch
import torch_geometric
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss
from tensordict import TensorDict

model = torch_geometric.nn.Linear(2, 3)
value = QValueActor(module=model, in_keys="obs", action_space="one_hot")

loss_1 = DQNLoss(value_network=value, action_space="one_hot")

data = TensorDict(
    {
        "obs": torch.zeros(3, 2),
        "action": torch.zeros(3, 3),
        "next": {
            "obs": torch.zeros(3, 2),
            "reward": torch.zeros(3, 1),
            "done": torch.zeros(3, 1).bool(),
            "terminated": torch.zeros(3, 1).bool(),
            "truncated": torch.zeros(3, 1).bool(),
        },
    },
    batch_size=[3],
)
loss_1(data)
