
Error While Starting 2nd Epoch #6002

Closed
mfatih7 opened this issue Dec 3, 2023 · 33 comments · Fixed by #6123

Comments

mfatih7 commented Dec 3, 2023

Hello

In one of my experiments, training and validation operations on both TPU v2 (single core) and TPU v3 (single core) devices finished successfully for the first chunk of the first epoch.
At the start of the second chunk, I got the error below:

Device is TPU:0
2023-12-03 18:43:08.088299: F ./torch_xla/csrc/runtime/debug_macros.h:20] Non-OK-status: status.status() status: INVALID_ARGUMENT: Transpose dimensions [0,1,-2,-1] are not a permutation of the operand dimensions (operand shape is f32[9,9]).
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230125::StatusOr<xla::Shape const*>&&)
        torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
        torch_xla::BuildDiagonalViewUpdate(xla::XlaOp, xla::XlaOp, long, long, long)
        torch_xla::DiagonalViewUpdate::Lower(torch_xla::LoweringContext*) const
        torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
        torch_xla::LoweringContext::LoweringContext(std::string const&, torch::lazy::BackendDevice, c10::ArrayRef<torch::lazy::Node const*>, std::unordered_map<torch::lazy::Node const*, torch::lazy::Util::EmitStatus, std::hash<torch::lazy::Node const*>, std::equal_to<torch::lazy::Node const*>, std::allocator<std::pair<torch::lazy::Node const* const, torch::lazy::Util::EmitStatus> > >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, absl::lts_20230125::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::string const>, bool, bool, bool)
        torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool)



        PyCFunction_Call
        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        PyEval_EvalCode



        PyRun_SimpleFileExFlags
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

https://symbolize.stripped_domain/r/?trace=7fe863d8b00b,7fe863d8b08f,7fe70b836517,7fe70b83657d,7fe70b4eb93f,7fe70b792fdc,7fe70b82c5ec,7fe70b82cade,7fe70b66a634,7fe70b66c6b8,7fe70b66ccfa,7fe70b66d127,7fe70b43c929,7fe70b43cd65,7fe70b41bf1f,5d5498,8fdaff&map=06b7eaee513554b0b69f7d4d65fa69f6858d5374:7fe706c02000-7fe715441e40 
*** SIGABRT received by PID 660093 (TID 660093) on cpu 24 from PID 660093; stack trace: ***
PC: @     0x7fe863d8b00b  (unknown)  raise
    @     0x7fe705f9f53a       1152  (unknown)
    @     0x7fe863d8b090  (unknown)  (unknown)
    @     0x7fe70b836518        432  ConsumeValue<>()
    @     0x7fe70b83657e         64  torch_xla::ShapeHelper::ShapeOfXlaOp()
    @     0x7fe70b4eb940        672  torch_xla::BuildDiagonalViewUpdate()
    @     0x7fe70b792fdd         80  torch_xla::DiagonalViewUpdate::Lower()
    @     0x7fe70b82c5ed        112  torch_xla::LoweringContext::LowerNode()
    @     0x7fe70b82cadf        224  torch_xla::LoweringContext::LoweringContext()
    @     0x7fe70b66a635       4192  torch_xla::XLAGraphExecutor::Compile()
    @     0x7fe70b66c6b9       1008  torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
    @     0x7fe70b66ccfb        560  torch_xla::XLAGraphExecutor::SyncTensorsGraph()
    @     0x7fe70b66d128       1072  torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph()
    @     0x7fe70b43c92a        720  torch_xla::(anonymous namespace)::StepMarker()
    @     0x7fe70b43cd66        128  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7fe70b41bf20        528  pybind11::cpp_function::dispatcher()
    @           0x5d5499  (unknown)  PyCFunction_Call
    @           0x8fdb00  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7fe863d8b00b,7fe705f9f539,7fe863d8b08f,7fe70b836517,7fe70b83657d,7fe70b4eb93f,7fe70b792fdc,7fe70b82c5ec,7fe70b82cade,7fe70b66a634,7fe70b66c6b8,7fe70b66ccfa,7fe70b66d127,7fe70b43c929,7fe70b43cd65,7fe70b41bf1f,5d5498,8fdaff&map=06b7eaee513554b0b69f7d4d65fa69f6858d5374:7fe706c02000-7fe715441e40,abbd016d9542b8098892badc0b19ea68:7fe6f8df5000-7fe7061b3cf0 
E1203 18:43:08.325543  660093 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1203 18:43:08.325562  660093 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1203 18:43:08.325577  660093 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1203 18:43:08.325588  660093 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1203 18:43:08.325612  660093 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1203 18:43:08.325626  660093 coredump_hook.cc:603] RAW: Dumping core locally.
E1203 18:43:08.702843  660093 process_state.cc:783] RAW: Raising signal 6 with default behavior

The dataloader does not do anything different for the 2nd chunk of the 1st epoch.

I appreciate any help.

best regards


mfatih7 commented Dec 4, 2023

Hello

I have more information about this.
For each chunk of an epoch, I have separate chunk files, in pickle format, for both training and validation.
I observe that in the first chunk of the first epoch, all iterations of both training and validation complete.
During training on the second chunk of the first epoch, the first iteration completes without any error, but on the second iteration I get the error above.
Each chunk goes through the same operations in the dataset.
When I start the procedure from the second chunk, I again observe the error after the first iteration.
I do not think the pickle files are the cause of any errors, since the same files are used in runs on GPUs without any error.

I appreciate any help.


mfatih7 commented Dec 4, 2023

Hi again

Even if I run the procedure using .hdf5 files instead of .pickle files, I get the same error at the start of the 2nd chunk of the first epoch.


mfatih7 commented Dec 5, 2023

Hello

I am also using a customized version of BatchSampler.
Is PyTorch XLA compatible with BatchSampler?

best regards


mfatih7 commented Dec 5, 2023

Hello

Even if I remove the BatchSampler, I get the same error.

After training on the first chunk finishes, if I force only validation chunks to be processed, no error occurs.
I think there is a bug in single-core TPU training.
Does the error message imply anything?

best regards


JackCaoG commented Dec 5, 2023

Let me take a look.


mfatih7 commented Dec 5, 2023

Hello @JackCaoG

Here is more information.
As I previously mentioned, the dataset, sampler, and dataloader behave the same in each epoch.
There is no problem with data loading or validation.
The problem is in backpropagation.

Here is the relevant part of the training loop:

while epoch < n_epochs:
    while chunk < n_chunks:

        for i, data in enumerate(dataloader_train):

            x_dev = data['x'].to(dev)
            y_dev = data['y'].to(dev)

            a_dev = data['a'].to(dev)
            b_dev = data['b'].to(dev)
            c_dev = data['c'].to(dev)
            d_dev = data['d'].to(dev)

            outputs = model(x_dev)

            loss_1 = get_loss_1(outputs, y_dev)

            loss_2, loss_3, _ = get_loss_2_3(outputs, a_dev, b_dev, c_dev, d_dev)

            if epoch < 1 and chunk < 1:
                loss = loss_1
            else:
                if cond2:
                    loss = loss_1 + loss_2
                elif cond3:
                    loss = loss_1 + loss_3

            loss.backward()
            optimizer.step()

            xm.mark_step()

In the first chunk of the first epoch, loss = loss_1 is selected, and backpropagation and optimization work well.
But after that, in the second chunk, either loss = loss_1 + loss_2 or loss = loss_1 + loss_3 is selected, and I get the error above.
I think IR recompilation occurs here.
The loop works fine on GPUs.
get_loss_2_3() is the function where eigh(), which has no XLA lowering, is used. Can the missing lowering trigger the error?
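
For what it's worth, a minimal way to confirm whether a new graph is compiled when the loss expression changes (this only uses the torch_xla debug metrics API; nothing here is specific to my repo):

import torch_xla.debug.metrics as met

# Print the metrics report after a few steps of each chunk; if the
# "CompileTime" metric keeps growing past the warm-up steps, a new graph
# is being compiled (e.g. when the loss expression switches).
print(met.metrics_report())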

best regards


JackCaoG commented Dec 5, 2023

I think it is one of the corner cases where we don't handle some of the view ops properly. Can you dump the IR and HLO using https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations? You should also run it with XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1. You can just share the last IR and last HLO, since they are most likely from the graph that caused the crash.

The real error is

Transpose dimensions [0,1,-2,-1] are not a permutation of the operand dimensions (operand shape is f32[9,9]).

so I want to check the HLO and see where this transpose comes from and why it is being generated (i.e., by which PyTorch op).
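
For reference, one way to set those debug variables from Python (assuming they are set before torch_xla is imported; the dump path below is just a placeholder):

import os

# Annotate IR/HLO with Python source information and save the graphs to a
# file, so the last IR and last HLO before the crash can be shared.
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"             # "text" dumps the IR instead
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/save.hlo"  # placeholder path

import torch_xla.core.xla_model as xm  # import torch_xla only after setting the variables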


mfatih7 commented Dec 5, 2023

I am working on it now.


mfatih7 commented Dec 5, 2023

Hello @JackCaoG

Here are the .hlo and .ir files. Are they enough?
ir.txt
hlo.txt


JackCaoG commented Dec 5, 2023

OK, I think it is from this line of IR:

  %920 = f32[9,9]{1,0} xla::diagonal_view_update(%919, %912), xla_shape=f32[9,9]{1,0}, offset=0, dim1=-2, dim2=-1

Let me see if I can find someone to take a look soon. What version of pytorch/xla are you using?


mfatih7 commented Dec 5, 2023

Thank you

I have Torch 2.1.1 and Torch-XLA 2.1.0 on my TPU v2 and v3 VMs.

Is this a big fix?
Can you estimate when it will be ready?

best regards


JackCaoG commented Dec 6, 2023

2.1 is fine. Let me see if I have cycles to try to repro it this week.


mfatih7 commented Dec 11, 2023

Hello @JackCaoG

Is there anything I can do to help with this problem?

JackCaoG commented:

I think I understand what the problem is, but I don't fully understand why. The only place I can find that will trigger diagonal_view_update is:

XLATensorPtr diagonal(const XLATensorPtr& input, int64_t offset, int64_t dim1,
                      int64_t dim2) {
  auto input_shape = input->shape();
  int64_t canonical_dim1 = torch::lazy::GetCanonicalDimensionIndex(
      dim1, input->shape().get().rank());
  int64_t canonical_dim2 = torch::lazy::GetCanonicalDimensionIndex(
      dim2, input->shape().get().rank());
  // See Note: [Disabling functionalization]
  if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
    DiagonalInfo diagonal_info;
    diagonal_info.offset = offset;
    diagonal_info.dim1 = canonical_dim1;
    diagonal_info.dim2 = canonical_dim2;
    ViewInfo view_info(ViewInfo::Type::kDiagonal, input_shape,
                       std::move(diagonal_info));
    return input->CreateViewTensor(std::move(view_info));
  }

The problem is that diagonal_view_update expects the dimensions to be positive numbers, but it is getting -1 and -2. What should happen is that GetCanonicalDimensionIndex converts them to 0 and 1, but for some reason that doesn't happen.
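
In other words, the expected canonicalization is roughly (a sketch, not the actual torch_xla code):

def canonical_dim(dim: int, rank: int) -> int:
    # Map a possibly negative dimension index onto [0, rank),
    # e.g. -2 -> 0 and -1 -> 1 for a rank-2 tensor.
    return dim if dim >= 0 else dim + rank

assert canonical_dim(-2, 2) == 0
assert canonical_dim(-1, 2) == 1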

Do you have an easy way for me to repro this? Ideally on a single device (no multiprocessing), and something I can just copy, paste, and run.

JackCaoG commented:

One way to unblock yourself is to figure out where this update comes from and manually use positive indices (this is only possible if this logic is triggered by some Python code, not by the C++ code). What I can tell is that it happens on a tensor of size [9, 9]. It is most likely coming from torch.diagonal somewhere.
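
Roughly, the workaround would look like this (illustrative only; the actual tensor and call site in your code will differ):

import torch

x = torch.randn(9, 9)

# These two calls are equivalent; the second spells out positive dimension
# indices, which avoids the negative-index path that diagonal_view_update
# is currently mishandling.
d_default = torch.diagonal(x, offset=0, dim1=-2, dim2=-1)
d_positive = torch.diagonal(x, offset=0, dim1=0, dim2=1)
assert torch.equal(d_default, d_positive)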


mfatih7 commented Dec 12, 2023

Hello @JackCaoG

Here is the repo for a single-core TPU run.

Please change the working directory on the relevant line in the config file.

Run run_train_1_1_TPU_single.py for debugging. After a successful short train and validation cycle, we get the error.

The error occurs in operations performed by the functions in loss_functions.py during the backward pass.

After your comments, I tried to find the error location in loss_functions.py by making modifications. You can observe the modifications in loss_functions_1_to_1.py.

To run training using loss_functions_1_to_1.py, you can activate the corresponding line.

You can compare the loss function files, but none of the changes helped.
I think none of them changed the way the graph is compiled.

I am ready to make any changes that can help.

best regards.

JackCaoG commented:

I am able to repro it; let me look into it. BTW, I am using the nightly build, so I enabled PT_XLA_DEBUG=1 and found that you have a lot of tensor accesses before the mark_step, around

                    loss_cls_val = loss_cls_val * loss_count_val + classif_loss.detach().cpu().numpy() * batch_size
                    loss_ess_val = loss_ess_val * loss_count_val + ess_loss.detach().cpu().numpy() * batch_size
                    loss_geo_val = loss_geo_val * loss_count_val + geo_loss.detach().cpu().numpy() * batch_size
                    loss_count_val = loss_count_val + batch_size
                    loss_cls_val = loss_cls_val / loss_count_val
                    loss_ess_val = loss_ess_val / loss_count_val  
                    loss_geo_val = loss_geo_val / loss_count_val

I saw

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 408e5e4020bd44086a4a21d9f7786a5c
Execution Analysis:   Number of Graph Inputs: 134
Execution Analysis:   Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   batch_symeig (/workspaces/dk2/repo/FeatureMatchingDebugSingle_TPUcore/17_featureMatching/loss_module/loss_functions.py:144)
Execution Analysis:   weighted_8points (/workspaces/dk2/repo/FeatureMatchingDebugSingle_TPUcore/17_featureMatching/loss_module/loss_functions.py:125)
Execution Analysis:   calculate_ess_loss_and_L2loss (/workspaces/dk2/repo/FeatureMatchingDebugSingle_TPUcore/17_featureMatching/loss_module/loss_functions.py:32)
Execution Analysis:   train_and_val (/workspaces/dk2/repo/FeatureMatchingDebugSingle_TPUcore/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_single.py:206)
Execution Analysis:   <module> (run_train_1_1_TPU_single.py:40)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 9b26dc2b072b03be9c945d9900df7167
Execution Analysis:   Number of Graph Inputs: 137
Execution Analysis:   Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   train_and_val (/workspaces/dk2/repo/FeatureMatchingDebugSingle_TPUcore/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_single.py:213)
Execution Analysis:   <module> (run_train_1_1_TPU_single.py:40)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: a8a69f42425b89f3573fe6db5d6701a
Execution Analysis:   Number of Graph Inputs: 3
Execution Analysis:   Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   train_and_val (/workspaces/dk2/repo/FeatureMatchingDebugSingle_TPUcore/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_single.py:214)
Execution Analysis:   <module> (run_train_1_1_TPU_single.py:40)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================

I think you just want to add a mark_step before accessing any of these tensors, which will make your code run slightly faster.
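
Something along these lines (a rough sketch; update_running_losses is a hypothetical helper, and the variable names come from the snippet quoted above):

import torch_xla.core.xla_model as xm

def update_running_losses(running, count, losses, batch_size):
    # Hypothetical helper: cut the graph once with mark_step(), then read the
    # loss scalars with .item(), instead of triggering a separate execution
    # for every .detach().cpu().numpy() call.
    xm.mark_step()
    new_count = count + batch_size
    updated = [(r * count + l.item() * batch_size) / new_count
               for r, l in zip(running, losses)]
    return updated, new_count

# Usage with the names from the quoted snippet (hypothetical):
# (loss_cls_val, loss_ess_val, loss_geo_val), loss_count_val = update_running_losses(
#     [loss_cls_val, loss_ess_val, loss_geo_val], loss_count_val,
#     [classif_loss, ess_loss, geo_loss], batch_size)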


mfatih7 commented Dec 12, 2023

OK, thank you.

JackCaoG commented:

Should be fixed by #6123; you can use tomorrow's nightly following https://github.com/pytorch/xla#python-packages


mfatih7 commented Dec 12, 2023

Thank you very much @JackCaoG

Should I wait until the fix is merged before installing the .whl file?

In terms of speed, do you suggest Python 3.10 or is Python 3.8 fine?

JackCaoG commented:

Yeah, you should wait for it to merge and then use the nightly the day after. I don't think the Python version makes much of a difference; I would just pick whichever fits your current machine environment.


mfatih7 commented Dec 12, 2023

Thank you very much @JackCaoG

The other error, #6048, is also very critical for me.
I cannot run multi-core operations.

I am happy to do anything to help.


mfatih7 commented Dec 13, 2023

Hello @JackCaoG

I have tested my training loop with the nightly releases.
Now I do not get the error, but I get the following warning.

/home/mfatih/env3_8/lib/python3.8/site-packages/torch/autograd/__init__.py:266: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at ../torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

I think there is an issue regarding the registration of autograd.
Is this a problem?

thank you.

JackCaoG commented:

You need to install nightly for both pytorch and pytorch/xla, so

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp310-cp310-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
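
After installing, a quick sanity check that both nightlies ended up in the environment (just the standard version attributes, nothing nightly-specific):

import torch
import torch_xla

# Both should report matching nightly versions,
# e.g. 2.2.0.devYYYYMMDD and 2.2.0+gitXXXXXXX.
print(torch.__version__, torch_xla.__version__)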


mfatih7 commented Dec 13, 2023

Actually, I used the following commands:

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

Let me try your suggestion.

JackCaoG commented:

Either way should work. I saw that last night's nightly build succeeded, so the new nightly whl should have the fix. We just happen to also build the pytorch nightly whl along with the torch_xla nightly whl.


mfatih7 commented Dec 13, 2023

When I install using

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl

and run pip list, I get

torch                    2.2.0
torch-xla                2.2.0+git270e1bc

and I get the following error at the beginning of execution:

/usr/bin/env /home/mfatih/env3_8/bin/python /home/mfatih/.vscode-server/extensions/ms-python.python-2023.22.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 48037 -- /home/mfatih/17_featureMatching/run_train_1_1_TPU_single.py
Traceback (most recent call last):
  File "/home/mfatih/17_featureMatching/run_train_1_1_TPU_single.py", line 3, in <module>
    import train_1_1_each_sample_in_single_batch_TPU_single
  File "/home/mfatih/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_single.py", line 1, in <module>
    import torch
  File "/home/mfatih/env3_8/lib/python3.8/site-packages/torch/__init__.py", line 237, in <module>
    from torch._C import *  # noqa: F403
ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory

When I install using

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

and run pip list, I get

torch                    2.2.0.dev20231213+cpu
torch-xla                2.2.0+git270e1bc
torchaudio               2.2.0.dev20231213+cpu
torchmetrics             1.2.1
torchsummary             1.5.1
torchvision              0.18.0.dev20231213+cpu

and I get the warning stated above:

/home/mfatih/env3_8/lib/python3.8/site-packages/torch/autograd/__init__.py:266: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at ../torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

What should I do?

thank you

JackCaoG commented:

Does this warning block the training, or does it hang the training?


mfatih7 commented Dec 13, 2023

It does not stop the training.
Training continues, but since the warning states that
"an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior."
I just want to check with you to be sure that backpropagation works correctly.
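
One way I could check this myself (a sketch, assuming a model and input small enough to run on CPU; compare_grads is not an existing utility):

import copy
import torch
import torch_xla.core.xla_model as xm

def compare_grads(model, make_loss, x, atol=1e-4):
    # Run the same backward pass on CPU and on the XLA device and compare
    # gradients; a large difference would indicate broken backpropagation.
    dev = xm.xla_device()
    m_cpu = copy.deepcopy(model).cpu()
    m_xla = copy.deepcopy(model).to(dev)

    make_loss(m_cpu(x.cpu())).backward()
    make_loss(m_xla(x.to(dev))).backward()
    xm.mark_step()  # materialize the XLA gradients

    for (name, p_cpu), (_, p_xla) in zip(m_cpu.named_parameters(),
                                         m_xla.named_parameters()):
        diff = (p_cpu.grad - p_xla.grad.cpu()).abs().max().item()
        print(name, diff)
        assert diff < atol, f"gradient mismatch in {name}"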

JackCaoG commented:

It doesn't seem like a torch_xla issue when I look into it; I tried your commands and I am able to run ResNet on XLA. You can post an issue on PyTorch if you have concerns about this error message. It seems like something new in the pytorch nightly.


mfatih7 commented Dec 13, 2023

What about this installation?

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl

Why does this installation not work?
I changed the wheel names to cp38. Is that the problem?
I am installing into a virtual environment with Python 3.8.


mfatih7 commented Dec 15, 2023

Hello @JackCaoG

I have opened an issue on the PyTorch forum, as you suggested.
What else can I do to be sure that backpropagation is healthy?
I just want to note that the same code runs on GPUs without any warning.
Therefore, the PyTorch folks might argue that it is related to PyTorch-XLA.


mfatih7 commented Dec 21, 2023

Hello @JackCaoG

What can we say about backpropagation health now?
What is your plan for the warning?
Is another commit needed for the nightly version?

best regards
