🐛 Bug
The nightly PyTorch XLA build (20220413 for torch, torchvision, and torch_xla) gives an unexpected error when all_reduce is used together with reduce_scatter, as follows:
2022-04-15 06:35:51.117279: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.
2022-04-15 06:35:51.117346: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)
This error doesn't happen on the 20220408 build. It is likely a side effect of #3484, which removes reduce_scatter's layout pinning.
To Reproduce
1. Allocate a v3-8 TPU VM with the tpu-vm-pt-1.10 runtime and install the 20220413 versions of torch, torchvision, and torch_xla, while keeping the 20220408 version of libtpu (the newer 20220413 libtpu was reported bad in #3502 (comment): ".data assignment fails when the new tensor is a different shape").
2. Run the example script (saved to /home/ronghanghu/test_all_gather_all_reduce_reduce_scatter.py, shown below). It prints the output that follows the script.
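The repro script itself is not preserved in this copy of the issue. The sketch below is a stand-in rather than the author's exact file: it spawns one process per TPU core and exercises xm.all_gather, xm.all_reduce, and xm.reduce_scatter in the same program. The tensor shapes, the REDUCE_SUM reduction, and the final print are illustrative assumptions.

```python
# test_all_gather_all_reduce_reduce_scatter.py -- minimal sketch, NOT the original
# script from the issue. Runs all three collectives in one program on a v3-8 TPU VM.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()

    # all_gather: concatenate a per-core tensor along dim 0.
    t = torch.full((4,), float(index), device=device)
    gathered = xm.all_gather(t, dim=0)

    # all_reduce: sum a tensor across all cores.
    r = torch.ones(8, device=device) * (index + 1)
    reduced = xm.all_reduce(xm.REDUCE_SUM, r)

    # reduce_scatter: sum across cores, then shard the result along dim 0.
    s = torch.ones(world_size * 4, device=device)
    scattered = xm.reduce_scatter(xm.REDUCE_SUM, s, scale=1.0,
                                  scatter_dim=0, shard_count=world_size)

    xm.mark_step()  # materialize the pending XLA graph (this is where compilation fails)
    xm.master_print(gathered.cpu(), reduced.cpu(), scattered.cpu())


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(), nprocs=8)
```

With the 20220413 wheels, a script along these lines aborts at compile time with the output below; on the 20220408 build it runs cleanly.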
ronghanghu@t1v-n-f1525942-w-0:~$ python3 test_all_gather_all_reduce_reduce_scatter.py
2022-04-15 06:57:23.387994: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-04-15 06:57:23.388052: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2022-04-15 06:57:33.186197: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.
2022-04-15 06:57:33.186250: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)
https://symbolize.stripped_domain/r/?trace=7feaf5dfd03b,7feaf5dfd0bf,7fea3683ebcf,7fea30eb3922,7fea30e71ebd,7fea30ec1db0,7fea30ec18ae,7fea2cd23ed3,7fea323651b8,7fea362f08a0,7fea362f2633,7fea36807cb1,7fea368074e0,7fea367ef8cb,7feaf5d9f608&map=a7d53509d90e6515f49ac71c94546cafe5812b54:7fea283df000-7fea396d4e30
*** SIGABRT received by PID 302216 (TID 303572) on cpu 1 from PID 302216; stack trace: ***
...
This example tries to cover all 3 distributed ops. However, the error is caused by all_reduce and reduce_scatter being used together (the error persists even if we remove the all_gather op).
I think it would be great to add this example to the test cases of PyTorch XLA.
Expected behavior
The error of "HloModule has a mix of layout constrained and unconstrained AllReduce instructions" should not happen when all_gather, all_reduce and reduce_scatter are used together.
Environment
Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
torch_xla version: 20220413 nightly from tpu-vm-pt-1.10 (see Step 1 above)
Additional context
This error breaks the FSDP implementation in #3431, which relies on all three APIs (all_gather, all_reduce, and reduce_scatter) in the same program.
The error is also reproducible on today's 20220415 build of torch, torchvision, and torch_xla.
cc: @JackCaoG