AssertionError: isolated node! hidden-tensor #73

Closed · turian opened this issue Jan 18, 2023 · 8 comments

turian commented Jan 18, 2023

Describe the bug

AssertionError: isolated node! hidden-tensor

on the particular model below.

To Reproduce

pip install -e 'git+https://github.com/kkoutini/passt_hear21@0.0.17#egg=hear21passt'

import torch
from hear21passt.base import get_basic_model, get_model_passt
from torchview import draw_graph

# get the PaSST model wrapper, which includes the Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # extracts the mel spectrogram from raw waveforms

# optionally replace the transformer with one that has the required number of classes, e.g. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
print(model.net)  # the transformer network

# now model contains mel + the pre-trained transformer, ready to be fine-tuned.
# It still expects input of shape [batch, seconds * 32000]; the sampling rate is 32 kHz.

model.train()

model_graph = draw_graph(
    model,
    input_size=(4, 32000 * 10),
    # graph left-to-right: https://github.com/mert-kurttutan/torchview/issues/56
    graph_dir="LR",
    depth=3,
    roll=True,
    expand_nested=True,
    # save_graph=True,
    # directory=tempdir,
)
    model_graph = draw_graph(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/torchview.py", line 225, in draw_graph
    model_graph.fill_visual_graph()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 123, in fill_visual_graph
    self.render_nodes()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 132, in render_nodes
    self.traverse_graph(self.collect_graph, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 186, in traverse_graph
    self.traverse_graph(action_fn, **new_kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 186, in traverse_graph
    self.traverse_graph(action_fn, **new_kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 186, in traverse_graph
    self.traverse_graph(action_fn, **new_kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 155, in traverse_graph
    action_fn(**kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 198, in collect_graph
    self.check_node(cur_node)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/computation_graph.py", line 431, in check_node
    assert not node.is_leaf() or not node.is_root(), (
AssertionError: isolated node! hidden-tensor

Expected behavior

A graph should be produced.

Note that this model works with depth=1

mert-kurttutan (Owner) commented Jan 19, 2023

Hi,
The error is essentially caused by the line here, where they use .item() to get the value inside the tensor. As a result, the value of the tensor is used downstream instead of the tensor itself, which means this tensor is not attached to the rest of the computational graph the way torch.Tensors normally are.

There are two potential solutions I see:

  • Ignore these tensors
  • Keep these tensors

Which option do you think is more reasonable? It depends on whether you want this part of the code to show up in the graph or not.
I will also drop the assertion, which is too strong, since this is a legitimate pattern in PyTorch code.
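
To illustrate the detachment described above, here is a minimal sketch in plain PyTorch (no torchview internals assumed): .item() returns a Python number, so nothing computed from it can be traced back to the tensor that produced it.

import torch

t = torch.randint(10, (1,))   # a tensor node is created here...
value = t.item()              # ...but only its plain Python value is used afterwards
print(type(t), type(value))   # <class 'torch.Tensor'> <class 'int'>

# Any arithmetic on `value` happens on Python numbers, so `t` never gains a
# tensor-to-tensor edge: it is both a leaf and a root, i.e. an "isolated node".
offset = 50 + value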

mert-kurttutan (Owner) commented Jan 19, 2023

Just for documentation purposes: the reason it worked for depth=1 is that these tensors sit at depth 2, which is larger than 1, so they are never even reached.
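
For reference, a hedged sketch of the two calls (reusing model and the draw_graph import from the reproduction above):

# depth=1 stops before the tensors created at depth 2, so the graph renders fine;
# depth=3 reaches them and used to trip the isolated-node assertion.
graph_ok = draw_graph(model, input_size=(4, 32000 * 10), depth=1)
# graph_fail = draw_graph(model, input_size=(4, 32000 * 10), depth=3)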

turian (Author) commented Jan 19, 2023

@mert-kurttutan I would be happy with either. If ignoring is easier, then perhaps just ignore them and log a warning to the user?

mert-kurttutan (Owner) commented Jan 20, 2023

Actually, these cases correspond to a very small part of the network. For this to happen, there can be no operation on the tensor whose item method is used, so I think ignoring them, without any message, is reasonable.
For instance,

fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()

Here the tensor is created and item is applied right after it, so only the tensor-creation part will be ignored. Since this kind of node involves no tensor operation and is not crucial, I will ignore it and not emit any warning.
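
As a standalone sketch of that distinction (not torchview's actual code; the variable names below are hypothetical stand-ins for self.fmin and self.fmin_aug_range):

import torch

fmin_base, fmin_aug_range = 50, 10  # hypothetical stand-ins for self.fmin / self.fmin_aug_range

# Ignored case: the tensor is consumed by .item() immediately, so no tensor
# operation ever touches it and there is nothing useful to draw.
fmin = fmin_base + torch.randint(fmin_aug_range, (1,)).item()

# Kept case: the tensor feeds another tensor operation before .item() is called,
# so it is a genuine part of the computational graph and stays in the drawing.
noise = torch.randn(4)
scaled = noise * 2.0
peak = scaled.max().item()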

turian (Author) commented Jan 20, 2023

Sounds great to me.

mert-kurttutan (Owner) commented

Can you try it now using the main branch of the repo? After this, I can release the new version.

turian (Author) commented Jan 21, 2023

@mert-kurttutan I'm traveling right now so I cannot check immediately.

turian (Author) commented Jan 23, 2023

On SHA 00bd35b, it works!

[image attached]

Thank you!
