TypeError: 'tuple' object is not callable #188
Hello, you may try PyTorch version 1.10.0 (pytorch==1.10.0) to solve this problem.
Another potential fix is to unpack the operator from the tuple that newer PyTorch versions return from torch._C._jit_get_operation:
-op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+op, _ = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
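If the same code has to run on both older and newer PyTorch releases, a version-tolerant variant is possible. The following is a minimal sketch, not the upstream fix. The helper name `unwrap_jit_operation` is hypothetical, and the sketch assumes two behaviors: newer releases return a tuple whose first element is the operator (the case the `op, _ = ...` unpacking relies on), and older releases return the operator itself.

```python
from typing import Any, Callable


def unwrap_jit_operation(result: Any) -> Callable[..., Any]:
    """Return the callable operator from a torch._C._jit_get_operation result.

    Hypothetical helper. Assumes newer PyTorch returns a tuple whose first
    element is the operator, and older PyTorch returns the operator directly.
    """
    if isinstance(result, tuple):
        op = result[0]  # newer API: (operator, <second element ignored>)
    else:
        op = result  # older API: the operator itself
    if not callable(op):
        # Fail loudly rather than deferring the error to the first call site.
        raise TypeError(f"expected a callable operator, got {type(op).__name__}")
    return op
```

In the patched module this would be called as `op = unwrap_jit_operation(torch._C._jit_get_operation('aten::grid_sampler_2d_backward'))`. Checking the type of the result, rather than comparing version numbers, avoids guessing exactly which release changed the API.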
This worked for me in PyTorch 1.12.0, thank you :)
The system runs normally after reinstalling PyTorch 1.10.0. Thank you! For CUDA 11.3:
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge
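When you pin an older release with the conda command above, it can help to confirm that the running interpreter actually picked up the pinned versions. The following is a hypothetical check and is not part of the repository. The names `parse_release`, `pin_mismatches` and `PINS` are illustrative, and the pins are copied from that conda command. Note that the conda package is named pytorch, while the module you import is torch.

```python
import re
from typing import Dict, List, Tuple


def parse_release(version: str) -> Tuple[int, ...]:
    """Parse the numeric release part of a version string.

    Drops a local suffix such as '+cu113' and any pre-release tail,
    so '1.10.0+cu113' -> (1, 10, 0).
    """
    public = version.split("+", 1)[0]
    match = re.match(r"\d+(?:\.\d+)*", public)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def pin_mismatches(installed: Dict[str, str], pins: Dict[str, str]) -> List[str]:
    """List packages whose installed release differs from the pinned one."""
    problems = []
    for name, wanted in pins.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed (wanted {wanted})")
        elif parse_release(have) != parse_release(wanted):
            problems.append(f"{name}: {have} installed, {wanted} pinned")
    return problems


# Pins taken from the conda command in this thread (module names, not conda names).
PINS = {"torch": "1.10.0", "torchvision": "0.11.0", "torchaudio": "0.10.0"}
```

With the packages imported, pass something like `{"torch": torch.__version__, "torchvision": torchvision.__version__, "torchaudio": torchaudio.__version__}` as `installed`. An empty list means the pins took effect. The comparison ignores local suffixes such as +cu113 that some builds append.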
This fixed it for me with torch==1.13.1.
Hi, I'm using Colab. How could I apply this method? I tried installing version 1.10.0, but it doesn't solve the issue for me.
Adapt to newer _jit_get_operation API that changed in pytorch/pytorch#76814 for #188, #193
Should be fixed by c233a91. Sorry for the inconvenience.
Sorry, I applied this method and re-ran train.py, but the same error still appeared.
The code does not work with any available stable version of PyTorch, which means it breaks out of the box on Colab. The fix is simple and comes from NVlabs@c233a91 ("Adapt to newer _jit_get_operation API that changed in pytorch/pytorch#76814 for NVlabs#188, NVlabs#193").
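On Colab, where the thread asks how to apply the fix, you can apply the one-line change to a cloned checkout with a small script instead of an editor. This is a sketch with two assumptions. First, the thread does not state the target file; this sketch assumes it is torch_utils/ops/grid_sample_gradfix.py relative to the repository root. Second, `apply_unpack_fix` is a hypothetical helper. The old and new lines are copied from the diff quoted in this thread.

```python
from pathlib import Path

# Line before and after the fix, copied from the diff in this thread.
OLD = "op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')"
NEW = "op, _ = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')"


def apply_unpack_fix(path: Path) -> bool:
    """Rewrite the single-assignment line so it unpacks the returned tuple.

    Returns True if the file was changed and False if it already contained
    only the fixed line. Raises ValueError if neither form is present, so an
    unexpected file version is reported instead of silently skipped.
    Indentation is preserved because only the statement text is replaced.
    """
    text = path.read_text()
    if OLD in text:
        path.write_text(text.replace(OLD, NEW))
        return True
    if NEW in text:
        return False
    raise ValueError(f"{path}: expected line not found")
```

For example, run `apply_unpack_fix(Path("torch_utils/ops/grid_sample_gradfix.py"))` from the repository root before launching train.py. After patching, the unpacking would presumably fail on older PyTorch releases that return the operator directly, so pair the patch with a recent PyTorch.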
Describe the bug
Everything runs up to the point of actually starting training. Then I get the error in the title: "TypeError: 'tuple' object is not callable".
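The message itself is ordinary Python: it appears whenever code tries to call a tuple. The sketch below reproduces it without PyTorch. The function `fake_get_operation` is a hypothetical stand-in for a lookup that returns a pair instead of the operator alone. The thread attributes this behavior change to pytorch/pytorch#76814. The second element of the pair is assumed here to be a list of overload names. It plays no role in the error.

```python
from typing import Any, Callable, List, Tuple


def fake_get_operation(name: str) -> Tuple[Callable[..., Any], List[str]]:
    """Hypothetical stand-in for a lookup returning (operator, overload names)."""
    def op_impl(*args: Any) -> str:
        return f"{name} called with {len(args)} args"
    return op_impl, [""]


def call_old_style(name: str) -> str:
    op = fake_get_operation(name)  # old code assumes this is the operator itself
    return op(1, 2)  # raises TypeError: 'tuple' object is not callable


def call_new_style(name: str) -> str:
    op, _ = fake_get_operation(name)  # unpack the pair, as in the fix
    return op(1, 2)
```

Here `call_new_style("aten::grid_sampler_2d_backward")` returns `"aten::grid_sampler_2d_backward called with 2 args"`, while `call_old_style` fails with the same message as in the title.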
To Reproduce
Start training with a custom 256x256 dataset.
ERROR