Jack cao g/xrt cherrypick 0721 #5334

Merged 20 commits on Jul 24, 2023
db7c973
Update inline style code to multiline (#5291)
wonjoolee95 Jul 10, 2023
e8e66d4
Fix typo in _test.yml (#5172)
malfet Jul 10, 2023
848c00d
[SPMD][Virtual Device]All tensors should be in SPMD:0 C++ device (#5284)
JackCaoG Jul 11, 2023
07d6f7f
Revert pr https://github.com/pytorch/xla/pull/2682 (#5215)
vanbasten23 Jul 12, 2023
9a0e24d
Make README more actionable (#5262)
will-cromar Jul 12, 2023
cc4f304
[SPMD] Use xs.Mesh in test_2d_tensor_3d_mesh (#5295)
khatwanimohit Jul 12, 2023
af0e0c3
[SPMD] Add FSDP sharding for test_train_spmd_linear_model.py (#5299)
alanwaketan Jul 12, 2023
a27382a
[SPMD] Avoid recompilations in xs.mark_sharding() (#5300)
alanwaketan Jul 13, 2023
ff97427
[SPMD] Support mark_sharding on IRs (#5301)
alanwaketan Jul 13, 2023
21784ce
[SPMD] Allow dumping post optimizations hlo (#5302)
alanwaketan Jul 13, 2023
67ab975
Add `_sharded_cpu_state_dict` for distributed checkpointing (#5288)
shahyash10 Jul 14, 2023
080fdcf
Supoort unordered sharding spec correctly (#5305)
JackCaoG Jul 17, 2023
aac03da
Support unordered sharding spec for partial replication (#5316)
JackCaoG Jul 18, 2023
def08b4
Fix mismatched GPU docker image in the doc. (#5319)
vanbasten23 Jul 18, 2023
37b8518
quick refactor on _get_group_assignment (#5318)
JackCaoG Jul 18, 2023
d9c92bd
Add tf independent serialization (#5308)
qihqi Jul 19, 2023
ebb5120
Disable coverage for now (#5321)
JackCaoG Jul 19, 2023
9a99540
Enable Some input output aliasing under SPMD (#5320)
JackCaoG Jul 19, 2023
8c20cbd
Use `_sharded_cpu_state_dict` functionality to Write Items for SPMD S…
shahyash10 Jul 19, 2023
d507067
handle single tensor for method send_to_device_single (#5317)
JackCaoG Jul 21, 2023
2 changes: 1 addition & 1 deletion .github/workflows/_test.yml
@@ -125,7 +125,7 @@ jobs:

INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '\"${CIRCLE_BUILD_NUM}\"', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}'
echo $INC_METADATA > inc_metadata.json
gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadtaa.json
gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json
fi

- name: Teardown Linux
51 changes: 26 additions & 25 deletions .github/workflows/build_and_test.yml
@@ -50,32 +50,33 @@ jobs:
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

test-cpu-coverage:
name: "Collect CPU test coverage"
if: github.event_name == 'push' && github.event.ref == 'refs/heads/master'
uses: ./.github/workflows/_test.yml
needs: build
with:
docker-image: ${{ needs.build.outputs.docker-image }}
collect-coverage: true
timeout-minutes: 120
disable-xrt: 1
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
# broken by gcc version update https://github.com/pytorch/xla/commit/e7e189961bd669c33939e269c248b391fe156d38
# test-cpu-coverage:
# name: "Collect CPU test coverage"
# if: github.event_name == 'push' && github.event.ref == 'refs/heads/master'
# uses: ./.github/workflows/_test.yml
# needs: build
# with:
# docker-image: ${{ needs.build.outputs.docker-image }}
# collect-coverage: true
# timeout-minutes: 120
# disable-xrt: 1
# secrets:
# gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

test-gpu-coverage:
name: "Collect GPU test coverage"
if: github.event_name == 'push' && github.event.ref == 'refs/heads/master'
uses: ./.github/workflows/_test.yml
needs: build
with:
docker-image: ${{ needs.build.outputs.docker-image }}
runner: linux.8xlarge.nvidia.gpu
timeout-minutes: 210
collect-coverage: true
disable-xrt: 1
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
# test-gpu-coverage:
# name: "Collect GPU test coverage"
# if: github.event_name == 'push' && github.event.ref == 'refs/heads/master'
# uses: ./.github/workflows/_test.yml
# needs: build
# with:
# docker-image: ${{ needs.build.outputs.docker-image }}
# runner: linux.8xlarge.nvidia.gpu
# timeout-minutes: 210
# collect-coverage: true
# disable-xrt: 1
# secrets:
# gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

push-docs:
name: "Build & publish docs"
298 changes: 181 additions & 117 deletions README.md

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions configuration.yaml
@@ -501,3 +501,16 @@ variables:
- Release Python's GIL when transferring data from the runtime.
type: bool
default_value: true
XLA_STABLEHLO_COMPILE:
description:
- Pass StableHLO to the XLA PjRt client for compilation. This compilation
flag is experimental. The default_value will be set to true when the
StableHLO workflow is mature.
type: bool
default_value: false
XLA_DUMP_POST_OPTIMIZATIONS:
description:
- Dump the HLO graph after optimizations. You need to use it together
with XLA_SAVE_TENSORS_FMT='hlo' and XLA_SAVE_TENSORS_FILE='your/location'.
type: bool
default_value: false
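
Both new flags are plain environment variables, so a launch script can toggle them before torch_xla initializes its runtime. Below is a minimal sketch of how they might be combined with the HLO-dump variables referenced in the description; the dump path is a placeholder, and the exact boolean parsing ("1" vs "true") is an assumption rather than something verified against the runtime.

```python
import os

# Set before importing torch_xla so the runtime picks the flags up.
os.environ["XLA_STABLEHLO_COMPILE"] = "1"  # experimental: compile via StableHLO

# XLA_DUMP_POST_OPTIMIZATIONS is only meaningful together with the HLO dump flags.
os.environ["XLA_DUMP_POST_OPTIMIZATIONS"] = "1"
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_dump.hlo"  # placeholder location

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = (x @ x).sum()
xm.mark_step()  # forces compilation; the post-optimization HLO is appended to the dump file
```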
13 changes: 8 additions & 5 deletions docs/gpu.md
@@ -1,11 +1,14 @@
# How to run with PyTorch/XLA:GPU

PyTorch/XLA enables PyTorch users to utilize the XLA compiler which supports accelerators including TPU, GPU, CPU and This doc will go over the basic steps to run PyTorch/XLA on a nvidia gpu instance
PyTorch/XLA enables PyTorch users to utilize the XLA compiler, which supports accelerators including TPU, GPU, and CPU. This doc will go over the basic steps to run PyTorch/XLA on an NVIDIA GPU instance.

## Create a GPU instance
PyTorch/XLA currently publishes prebuilt docker images and wheels with CUDA 11.7/11.8 and Python 3.8. We recommend creating a GPU instance with a matching configuration. For a full list of docker images and wheels, please refer to [this doc](https://github.com/pytorch/xla/tree/jackcao/gpu_doc#-available-images-and-wheels).

## Environment Setup

To create a GPU VM in Google Compute Engine, follow the [Google Cloud documentation](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus).

### Docker
```
sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7
@@ -15,15 +18,15 @@ curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
sudo docker run --gpus all -it -d gcr.io/tpu-pytorch/xla:nightly_3.7\8_cuda_11.2 bin/bash
sudo docker run --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7 bin/bash
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash
```

Note that you need to restart Docker to make GPU devices visible in the container. After logging into the container, you can use `nvidia-smi` to verify the device is set up correctly.

```
(pytorch) root@20ab2c7a2d06:/# nvidia-smi
Thu Dec 8 06:24:29 2022
Thu Dec 8 06:24:29 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
Expand All @@ -35,7 +38,7 @@ Thu Dec 8 06:24:29 2022
| N/A 36C P0 38W / 300W | 0MiB / 16384MiB | 1% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
@@ -70,4 +73,4 @@ Epoch 1 train begin 06:12:38
| Training Device=xla:0/0 Epoch=1 Step=120 Loss=2.68816 Rate=388.35 GlobalRate=169.49 Time=06:14:09
```
## AMP (AUTOMATIC MIXED PRECISION)
AMP is very useful on GPU training and PyTorch/XLA reuse Cuda's AMP rule. You can checkout our [mnist example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py). Note that we also used a modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) to avoid the additional sync between device and host.
AMP is very useful for GPU training, and PyTorch/XLA reuses CUDA's AMP rules. You can check out our [mnist example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py). Note that we also use a modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) to avoid the additional sync between device and host.
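
As a rough companion to the AMP note above, the sketch below pairs `torch_xla.amp.autocast` and `GradScaler` with a sync-free SGD optimizer on a single XLA:GPU device. It is a simplified reading of the linked mnist/imagenet examples, so treat the exact call pattern (especially the scaler usage) as an assumption rather than the canonical recipe.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler
import torch_xla.amp.syncfree as syncfree  # modified optimizers that avoid a device/host sync

device = xm.xla_device()
model = nn.Linear(128, 10).to(device)
optimizer = syncfree.SGD(model.parameters(), lr=0.01, momentum=0.9)
scaler = GradScaler()  # gradient scaling is only needed for fp16 AMP on GPU
loss_fn = nn.CrossEntropyLoss()

data = torch.randn(32, 128, device=device)
target = torch.randint(0, 10, (32,), device=device)

for step in range(10):
    optimizer.zero_grad()
    with autocast(device):
        loss = loss_fn(model(data), target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    xm.mark_step()  # materialize the lazily traced step
```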
11 changes: 9 additions & 2 deletions test/spmd/test_train_spmd_linear_model.py
@@ -15,7 +15,7 @@

MODEL_OPTS = {
'--sharding': {
'choices': ['batch', 'megatron-lm'],
'choices': ['batch', 'megatron-lm', 'fsdp'],
'nargs': '+',
'default': [],
},
@@ -58,7 +58,6 @@ def forward(self, x):

def train():
print('===> Preparing data..')
num_epochs = 18
lr = 0.1
train_loader = xu.SampleGenerator(
data=(torch.zeros(FLAGS.batch_size, FLAGS.input_dim),
@@ -78,6 +77,14 @@ def train():
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))

if 'fsdp' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
print('Sharding model weights')
# Shard the weights according to their 0th dim
xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
xs.mark_sharding(model.fc2.weight, mesh, (0, 1))

if 'megatron-lm' in FLAGS.sharding:
print('Sharding model weights')
# Shard the first layer's weights row-wise
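
The new `fsdp` branch shards each linear layer's weight along its 0th dimension with partition spec `(0, 1)`. A self-contained sketch of the same call pattern is below; the mesh construction and the `XLA_USE_SPMD` environment variable reflect how the SPMD tests of this era set things up and should be read as assumptions, not as part of this diff.

```python
import os
os.environ.setdefault("XLA_USE_SPMD", "1")  # assumption: enable the SPMD virtual device

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_device_count()
# A (num_devices, 1) mesh: axis 0 shards data/weights, axis 1 is degenerate.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "model"))

linear = torch.nn.Linear(1024, 1024).to(xm.xla_device())
# Partition spec (0, 1): weight dim 0 follows mesh axis 0 (sharded across
# devices), weight dim 1 follows mesh axis 1 (size 1, effectively replicated).
xs.mark_sharding(linear.weight, mesh, (0, 1))
```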
110 changes: 81 additions & 29 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -1,4 +1,5 @@
import os
import sys
import tempfile
import unittest
import test_xla_sharding_base
@@ -14,6 +15,8 @@
create_default_global_save_plan,
)
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner
from torch_xla.experimental._distributed_checkpoint_helpers import (
_sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)


class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):
@@ -54,7 +57,8 @@ def _save_and_restore(self,
model_in,
model_out,
save_planner=None,
load_planner=None):
load_planner=None,
is_sharded_cpu_state_dict=False):
"""
Checkpoint model_in using the provided save_planner and load into model_out
using the provided load_planner, and assert model_out equals model_in after
@@ -63,18 +67,22 @@ def _save_and_restore(self,
tmpdir = tempfile.mkdtemp()

# Save an unsharded model using the provided save planner
model_in_state_dict = model_in.state_dict()
if is_sharded_cpu_state_dict:
model_in_state_dict = _sharded_cpu_state_dict(model_in_state_dict)
model_out_state_dict = model_out.state_dict()
dist_cp.save_state_dict(
state_dict=model_in.state_dict(),
state_dict=model_in_state_dict,
storage_writer=dist_cp.FileSystemWriter(tmpdir),
planner=save_planner,
no_dist=True, # Single-host checkpoint doesn't require a process group
)

# Load the checkpoint using the provided load planner
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
self.assertFalse(torch.allclose(p1, p2))

dist_cp.load_state_dict(
state_dict=model_out.state_dict(),
state_dict=model_out_state_dict,
storage_reader=dist_cp.FileSystemReader(tmpdir),
planner=load_planner,
no_dist=True, # Single-host checkpoint doesn't require a process group
@@ -92,9 +100,15 @@ def test_unsharded_to_sharded(self):
# TODO(jonbolin): Enable tests for resharding into coarser meshes
@unittest.skip("View assignment with virtual device is not yet supported")
def test_sharded_to_unsharded(self):
model = self.SimpleLinear().to(xm.xla_device())
sharded_model = self._get_sharded_model()
self._save_and_restore(sharded_model, model, save_planner=SPMDSavePlanner())
for chkpt_on_cpu in [True, False]:
with self.subTest(chkpt_on_cpu):
model = self.SimpleLinear().to(xm.xla_device())
sharded_model = self._get_sharded_model()
self._save_and_restore(
sharded_model,
model,
save_planner=SPMDSavePlanner(),
is_sharded_cpu_state_dict=chkpt_on_cpu)

# TODO(jonbolin): Enable tests for resharding into coarser meshes
@unittest.skip("View assignment with virtual device is not yet supported")
Expand Down Expand Up @@ -183,15 +197,16 @@ def test_resolve_and_commit_sharded_tensor(self):

class SPMDSavePlannerTest(DistributedCheckpointTestBase):

def _get_save_planner(self, model):
def _get_save_planner(self, model, is_sharded_cpu_state_dict=False):
# Create an SPMDSavePlanner for the given model.
planner = SPMDSavePlanner()
planner.set_up_planner(model.state_dict(), True)
if not is_sharded_cpu_state_dict:
planner.set_up_planner(model.state_dict(), True)
else:
planner.set_up_planner(_sharded_cpu_state_dict(model.state_dict()), True)
return planner

def test_state_dict_separation(self):
model = self._get_sharded_model()
planner = self._get_save_planner(model)
def _planner_assertions(self, planner):
if self.n_devices > 1:
# The state_dict should be flattened and separated
self.assertCountEqual(planner.sharded_state_dict, ['fc1.weight'])
@@ -205,27 +220,46 @@ def test_state_dict_separation(self):
planner.unsharded_state_dict,
['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

def test_local_save_plan(self):
def test_state_dict_separation(self):
model = self._get_sharded_model()
planner = self._get_save_planner(model)
plan = planner.create_local_plan()
parameter_count = len(list(model.parameters()))
self._planner_assertions(planner)

def test_save_state_dict_with_cpu_shards(self):
model = self._get_sharded_model()
planner = self._get_save_planner(model, is_sharded_cpu_state_dict=True)
self._planner_assertions(planner)
if self.n_devices > 1:
# When the model is sharded across devices, fc1.weight will result in
# self.n_devices WriteItems while all other tensors result in a single
# WriteItem.
sharded_write_items = [
wi for wi in plan.items if wi.index.fqn == 'fc1.weight'
]
self.assertEqual(self.n_devices, len(sharded_write_items))
# Every other parameter should have a single WriteItem
unsharded_write_items = [
x for x in plan.items if x not in sharded_write_items
]
self.assertEqual(parameter_count - 1, len(unsharded_write_items))
else:
self.assertTrue(
isinstance(planner.sharded_state_dict['fc1.weight'], _CpuShards))

def test_local_save_plan(self):

def _write_item_assertions(plan, n_devices, parameter_count):
if n_devices > 1:
# When the model is sharded across devices, fc1.weight will result in
# self.n_devices WriteItems while all other tensors result in a single
# WriteItem.
sharded_write_items = [
wi for wi in plan.items if wi.index.fqn == 'fc1.weight'
]
self.assertEqual(self.n_devices, len(sharded_write_items))
# Every other parameter should have a single WriteItem
unsharded_write_items = [
x for x in plan.items if x not in sharded_write_items
]
self.assertEqual(parameter_count - 1, len(unsharded_write_items))
else:
self.assertEqual(parameter_count, len(plan.items))
# If unsharded, there should be a single WriteItem per model parameter
self.assertEqual(parameter_count, len(plan.items))

for chkpt_on_cpu in [True, False]:
with self.subTest(chkpt_on_cpu):
model = self._get_sharded_model()
planner = self._get_save_planner(model, chkpt_on_cpu)
plan = planner.create_local_plan()
parameter_count = len(list(model.parameters()))
_write_item_assertions(plan, self.n_devices, parameter_count)

@unittest.skipIf(xr.global_device_count() == 1,
"Multiple devices required to shard tensors")
Expand All @@ -244,6 +278,24 @@ def test_resolve_shard_data(self):
self.assertTrue(torch.allclose(shard.data, resolved_data))


class DistributedCheckpointHelpersTest(DistributedCheckpointTestBase):

def test_sharded_cpu_state_dict(self):
model = self.SimpleLinear().to(xm.xla_device())
state_dict = model.state_dict()
sharded_cpu_state_dict = _sharded_cpu_state_dict(state_dict)
self.assertCountEqual(sharded_cpu_state_dict,
['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
for name, param in sharded_cpu_state_dict.items():
if name == 'fc1.weight':
# _sharded_cpu_state_dict returns _CpuShards only for sharded tensors
if _is_sharded_tensor(param):
self.assertTrue(isinstance(param, _CpuShards))
else:
self.assertTrue(isinstance(param, torch.Tensor))
self.assertTrue(param.device == torch.device("cpu"))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
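
Outside the test harness, the new helper and planner compose roughly as follows for a single-host save; every name here (`_sharded_cpu_state_dict`, `SPMDSavePlanner`, `no_dist=True`) comes from the diff above, while the model and checkpoint directory are placeholders.

```python
import tempfile

import torch
import torch.distributed.checkpoint as dist_cp
import torch_xla.core.xla_model as xm
from torch_xla.experimental.distributed_checkpoint import SPMDSavePlanner
from torch_xla.experimental._distributed_checkpoint_helpers import _sharded_cpu_state_dict

# Placeholder model; in practice its weights would be sharded via xs.mark_sharding.
model = torch.nn.Linear(16, 16).to(xm.xla_device())

# Convert sharded XLA tensors into host-side _CpuShards so they can be written.
cpu_state_dict = _sharded_cpu_state_dict(model.state_dict())

tmpdir = tempfile.mkdtemp()
dist_cp.save_state_dict(
    state_dict=cpu_state_dict,
    storage_writer=dist_cp.FileSystemWriter(tmpdir),
    planner=SPMDSavePlanner(),
    no_dist=True,  # single-host checkpoint does not require a process group
)
```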