xmp.spawn fails with models containing register_buffer on large TPU Pods (>=v3-128) #3068

Closed
tgisaturday opened this issue Aug 3, 2021 · 6 comments

tgisaturday commented Aug 3, 2021

🐛 Bug

Hi, I'm trying to run my code on a Cloud TPU Pod, and models that use self.register_buffer fail on large pods (>=v3-128) with SIGSEGV. I'm not sure self.register_buffer is the only cause, but models without register_buffer don't trigger any SIGSEGV.
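For reference, the failing pattern boils down to roughly the following (a minimal sketch, not code from the repo; the module and buffer names are made up):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

class BufferedModule(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        # non-trainable state registered as a buffer, as in the failing models
        self.register_buffer("running_stats", torch.zeros(dim))

    def forward(self, x):
        return x + self.running_stats

def _mp_fn(index):
    device = xm.xla_device()
    model = BufferedModule().to(device)
    out = model(torch.randn(4, 16, device=device))
    xm.mark_step()

if __name__ == "__main__":
    # on a pod, each worker runs this via xla_dist, as in the commands below
    xmp.spawn(_mp_fn, args=())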

To Reproduce

Clone this repo: https://github.com/tgisaturday/dalle-lightning

Install the dependencies with pip install -r requirements.txt.

Here is an example TPU setup:

export PROJECT_ID=lgai-vision-tpu
export TPU_NAME=tpu-pod-128
export ZONE=europe-west4-a
export RUNTIME_VERSION=v2-alpha

gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
--zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-128 \
--version ${RUNTIME_VERSION} 

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} --project ${PROJECT_ID} --worker=all \
  --command "git clone https://github.com/tgisaturday/dalle-lightning.git"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} --project ${PROJECT_ID} --worker=all \
  --command "pip3 install -r dalle-lightning/requirements.txt"

These two commands work:

python3 -m torch_xla.distributed.xla_dist --tpu=tpu-pod-128 --restart-tpuvm-pod-server -- python3 dalle-lightning/train_vae.py --use_tpus --model vqvae --fake_data
python3 -m torch_xla.distributed.xla_dist --tpu=tpu-pod-128 --restart-tpuvm-pod-server -- python3 dalle-lightning/train_vae.py --use_tpus --model gvqvae --fake_data

These two fail with SIGSEGV:

python3 -m torch_xla.distributed.xla_dist --tpu=tpu-pod-128 --restart-tpuvm-pod-server -- python3 dalle-lightning/train_vae.py --use_tpus --model evqvae --fake_data
python3 -m torch_xla.distributed.xla_dist --tpu=tpu-pod-128 --restart-tpuvm-pod-server -- python3 dalle-lightning/train_dalle.py --use_tpus --vae debugvae --fake_data

All four commands work fine with smaller TPU pods.

Here is the error message from the SIGSEGV. Both failing commands show the same error.

2021-08-03 07:39:11 10.164.0.127 [14] 2021-08-03 07:39:11.524825: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2021-08-03 07:39:11 10.164.0.127 [14] 2021-08-03 07:39:11.524883: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2021-08-03 07:39:12 10.164.0.122 [0] https://symbolize.stripped_domain/r/?trace=7f5e5d1eab51,7f5e60fcf20f,7f5e5d1e2059,7ffc6440a4bf&map=5f4fb88af97be3ecacc71363136bb015b2a07119:7f5e5d1d4000-7f5e5d1f808c
2021-08-03 07:39:12 10.164.0.122 [0] *** SIGSEGV (@0x7f5d30fe2528), see gl__________25#s15 received by PID 9039 (TID 9039) on cpu 13; stack trace: ***
2021-08-03 07:39:12 10.164.0.122 [0] PC: @ 0x7f5e5d1eab51 (unknown) (unknown)
2021-08-03 07:39:12 10.164.0.122 [0] @ 0x7f5c67ee81e0 976 (unknown)
2021-08-03 07:39:12 10.164.0.122 [0] @ 0x7f5e60fcf210 146743520 (unknown)
2021-08-03 07:39:12 10.164.0.122 [0] @ 0x7f5e5d1e205a 144 GOMP_parallel
2021-08-03 07:39:12 10.164.0.122 [0] @ 0x7ffc6440a4c0 (unknown) (unknown)
2021-08-03 07:39:12 10.164.0.122 [0] https://symbolize.stripped_domain/r/?trace=7f5e5d1eab51,7f5c67ee81df,7f5e60fcf20f,7f5e5d1e2059,7ffc6440a4bf&map=5f4fb88af97be3ecacc71363136bb015b2a07119:7f5e5d1d4000-7f5e5d1f808c,ca1b7ab241ee28147b3d590cadb5dc1b:7f5c5b1e9000-7f5c6821bb20
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.689429 9039 coredump_hook.cc:292] RAW: Remote crash data gathering hook invoked.
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.689484 9039 coredump_hook.cc:384] RAW: Skipping coredump since rlimit was 0 at process start.
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.689511 9039 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.689527 9039 coredump_hook.cc:447] RAW: Sending fingerprint to remote end.
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.689539 9039 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.689590 9039 coredump_hook.cc:451] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.689612 9039 coredump_hook.cc:525] RAW: Discarding core.
2021-08-03 07:39:12 10.164.0.122 [0] E0803 07:39:12.711387 9039 process_state.cc:771] RAW: Raising signal 11 with default behavior

Additional Information

The only difference between vqvae, gvqvae, and evqvae is that each model uses a different VectorQuantizer class from https://github.com/tgisaturday/dalle-lightning/blob/master/pl_dalle/modules/vqvae/quantize.py

I'm not 100% sure register_buffer is the only problem for dalle, but it also uses register_buffer here: https://github.com/tgisaturday/dalle-lightning/blob/b4164106fee0d25ad0312ff1cbc24516c78b905d/pl_dalle/models/dalle.py#L394
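To make the suspected pattern concrete: an EMA-style vector quantizer typically keeps its codebook statistics in registered buffers and updates them in place with copy_(), roughly like this (an illustrative sketch with made-up names, not the actual code from quantize.py):

import torch
import torch.nn as nn

class EMAQuantizerSketch(nn.Module):
    def __init__(self, n_embed=512, dim=64, decay=0.99):
        super().__init__()
        self.decay = decay
        # codebook statistics kept as non-trainable buffers
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", torch.randn(n_embed, dim))

    @torch.no_grad()
    def ema_update(self, counts, embed_sum):
        # in-place EMA updates on the buffers via copy_() -- the suspected trigger
        self.cluster_size.copy_(self.decay * self.cluster_size + (1 - self.decay) * counts)
        self.embed_avg.copy_(self.decay * self.embed_avg + (1 - self.decay) * embed_sum)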

@tgisaturday (Author)

@zcain117 I removed register_buffer from the evqvae models and the SIGSEGV still happens. xmp.spawn seems to fail while building the graph.

@tgisaturday (Author)

Using nn.Parameter as a workaround to replace .copy_() solved this issue in the evqvae model. Still working on the dalle model.
tgisaturday/dalle-lightning@53ad464

@zcain117 Are in-place copy operations not currently supported by torch-xla?
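For context, the workaround amounts to something like this (an illustrative sketch of the idea, not the actual diff in 53ad464):

import torch
import torch.nn as nn

class EMAQuantizerWorkaround(nn.Module):
    def __init__(self, n_embed=512, decay=0.99):
        super().__init__()
        self.decay = decay
        # workaround: keep the EMA state as a frozen Parameter instead of a buffer
        self.cluster_size = nn.Parameter(torch.zeros(n_embed), requires_grad=False)

    @torch.no_grad()
    def ema_update(self, counts):
        # compute the new value out of place and overwrite .data,
        # avoiding the in-place copy_() on tracked state
        self.cluster_size.data = self.decay * self.cluster_size + (1 - self.decay) * counts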

JackCaoG (Collaborator) commented Aug 5, 2021

copy_ is supported. Looking at the implementation, it just makes a copy of the other tensor and clears the self tensor's metadata. It is a bit weird that this behavior would cause a crash.
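As a point of reference, copy_ in isolation is just this (a quick sketch, not taken from the thread); the crash presumably needs the larger graph and pod setup to reproduce:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
dst = torch.zeros(4, 4, device=device)
src = torch.ones(4, 4, device=device)
dst.copy_(src)   # in-place copy on an XLA tensor
xm.mark_step()   # materialize the pending lazy graph
print(dst.cpu())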

@tgisaturday (Author)

copy_ is supported. Looking at the implementation, it just makes a copy of the other tensor and clears the self tensor's metadata. It is a bit weird that this behavior would cause a crash.

Maybe a certain combination of operations with copy_ causes a graph-building failure or an OOM.

zcain117 (Collaborator) commented Aug 5, 2021

I guess immediately after the copy we would have at least double the amount of memory for that tensor, since we store it in the intermediate var copy_value. So maybe if it's a large tensor it could lead to OOM? I think copy_value would get garbage collected after that method returns, so it wouldn't be around for too long.

Strange that this happens on v3-128 and not on v3-64.

stale bot commented Sep 7, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the stale (Has not had recent activity) label on Sep 7, 2021
stale bot closed this as completed on Oct 2, 2021