Add support for dynamic shape in dynamo (pytorch#7676)
Co-authored-by: JackCaoG <jackcao@google.com>
2 people authored and yitongh committed Oct 11, 2024
1 parent 6245c47 commit 3704c3b
Showing 5 changed files with 313 additions and 9 deletions.
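In short, this change lets a module compiled with the openxla backend accept inputs whose batch dimension varies between calls: Dynamo traces once with symbolic shapes, and only XLA recompiles for each new shape, while previously seen shapes hit the compilation cache. Below is a minimal sketch of the usage pattern the new tests exercise; the toy model and shapes are illustrative, not taken from the commit.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
# illustrative toy model, not part of this commit
linear = torch.nn.Linear(10, 20).to(device)

# dynamic=True makes Dynamo trace with symbolic shapes, so the compiled
# graph is reused when only the batch dimension changes.
compiled_linear = torch.compile(linear, backend='openxla', dynamic=True)

for batch in (20, 30, 20):
  out = compiled_linear(torch.randn(batch, 10, device=device))
  xm.mark_step()  # materialize the pending XLA computation

On the first call Dynamo extracts and compiles the graph (the DynamoExtractCompiledGraph counter fires once); a new batch size triggers only an XLA recompile, and a repeated shape reuses the cached executable. The metrics assertions in the new test file below verify exactly this behavior.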
17 changes: 11 additions & 6 deletions test/dynamo/test_dynamo.py
@@ -27,7 +27,7 @@
 
 
 def _is_on_tpu():
-  return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'
+  return xr.device_type() == 'TPU'
 
 
 skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')
@@ -333,6 +333,14 @@ def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
             rtol=1e-05,
             atol=1e-05))
 
+  def get_loader(self, device, sample_count, batch_size=4):
+    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
+    loader = xu.SampleGenerator(
+        data=(torch.randn(batch_size, 3, 224, 224, device=device),
+              torch.zeros(batch_size, dtype=torch.int64, device=device)),
+        sample_count=sample_count)
+    return loader
+
   @skipOnTpu
   @parameterized.parameters(
       True,
@@ -342,10 +350,7 @@ def test_resnet18(self, initialize_on_cuda):
     device = self._choose_proper_device(initialize_on_cuda)
     batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
-    loader = xu.SampleGenerator(
-        data=(torch.randn(batch_size, 3, 224, 224, device=device),
-              torch.zeros(batch_size, dtype=torch.int64, device=device)),
-        sample_count=sample_count)
+    loader = self.get_loader(device, sample_count, batch_size=4)
     resnet18 = torchvision.models.resnet18()
     resnet18.eval()
     device_resnet18 = torchvision.models.resnet18()
@@ -356,8 +361,8 @@
     xm.mark_step()
     xm.wait_device_ops()
     met.clear_all()
+    dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
     for data, _ in loader:
-      dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
       output = dynamo_resnet18(data)
       output_cpu = resnet18(data.cpu())
       self.assertTrue(
234 changes: 234 additions & 0 deletions test/dynamo/test_dynamo_dynamic_shape.py
@@ -0,0 +1,234 @@
import unittest
import sys

import torch
import torch_xla
import torchvision

from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.utils.utils as xu


def _is_on_tpu():
return xr.device_type() == 'TPU'


class DynamoDynamicShapeBasicTest(unittest.TestCase):

def _get_loader(self, device, sample_count, batch_size=4):
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
loader = xu.SampleGenerator(
data=(torch.randn(batch_size, 3, 224, 224, device=device),
torch.zeros(batch_size, dtype=torch.int64, device=device)),
sample_count=sample_count)
return loader

  def _get_linear_and_input(self, in_dim: int, out_dim: int, batch_dim: int,
                            device: torch.device):
    dummy_linear = torch.nn.Linear(in_dim, out_dim)
    dummy_linear_xla = torch.nn.Linear(in_dim, out_dim).to(device)
    dummy_linear_xla.load_state_dict(dummy_linear.state_dict())
    input = torch.randn(batch_dim, in_dim)
    input_xla = input.to(device)
    return (dummy_linear, dummy_linear_xla, input, input_xla)

def test_dynamic_shape_basic(self):
torch_xla.manual_seed(100)
device = torch_xla.device()
# model setup
dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input(
10, 20, 20, device)
compiled_linear_xla = torch.compile(
dummy_linear_xla, backend="openxla", dynamic=True)
xm.wait_device_ops()
met.clear_all()

# first run
res = dummy_linear(input)
res_xla = compiled_linear_xla(input_xla)
    # TPU matmul happens in bf16, so use a looser tolerance
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
# torch.compile should be called once
self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

# second run with different input shape
input = torch.randn(30, 10)
input_xla = input.to(device)
met.clear_all()
res = dummy_linear(input)
res_xla = compiled_linear_xla(input_xla)
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
# torch.compile should not retrace but xla will recompile
self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

  def test_dynamic_shape_multiple_batches(self):
torch_xla.manual_seed(100)
device = torch_xla.device()
# model setup
in_dim = 16
    out_dim = 32
batch = 8
dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input(
        in_dim, out_dim, batch, device)
compiled_linear_xla = torch.compile(
dummy_linear_xla, backend="openxla", dynamic=True)
xm.wait_device_ops()
met.clear_all()

# first run with batch 8
res_xla = compiled_linear_xla(input_xla)
# torch.compile should be called once
xm.wait_device_ops()
self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

# then run with batch 16
met.clear_all()
batch = 16
input_xla = torch.randn(batch, in_dim).to(device)
res_xla = compiled_linear_xla(input_xla)
# torch.compile should not retrace but xla will recompile
xm.wait_device_ops()
self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

# then run with batch 32
met.clear_all()
batch = 32
input_xla = torch.randn(batch, in_dim).to(device)
res_xla = compiled_linear_xla(input_xla)
# torch.compile should not retrace but xla will recompile
xm.wait_device_ops()
self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

# run with batch 8 again and make sure we don't recompile HLO
met.clear_all()
batch = 8
input = torch.randn(batch, in_dim)
input_xla = input.to(device)
res_xla = compiled_linear_xla(input_xla)
res = dummy_linear(input)
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
# torch.compile should not retrace, xla also will not compile
self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
self.assertNotIn('CompileTime', met.metric_names())
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

def test_dynamic_shape_mix_with_non_dynamic(self):
torch_xla.manual_seed(100)
device = torch_xla.device()
# model setup
in_dim = 15
    out_dim = 31
    out_dim_2 = 33
batch = 8
_, dummy_linear_xla, _, input_xla = self._get_linear_and_input(
        in_dim, out_dim, batch, device)
dynamic_compiled_linear_xla = torch.compile(
dummy_linear_xla, backend="openxla", dynamic=True)
_, dummy_linear_xla_2, _, input_xla_2 = self._get_linear_and_input(
        in_dim, out_dim_2, batch, device)
static_compiled_linear_xla = torch.compile(
dummy_linear_xla_2, backend="openxla")
xm.wait_device_ops()
met.clear_all()

# first run the dynamic compiled model
res_xla = dynamic_compiled_linear_xla(input_xla)
# torch.compile should be called once
xm.wait_device_ops()
self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # and then run the dynamic compiled model with a different batch size
met.clear_all()
batch = 32
input_xla = torch.randn(batch, in_dim).to(device)
res_xla = dynamic_compiled_linear_xla(input_xla)
# torch.compile should not retrace but xla will recompile
xm.wait_device_ops()
self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

# now run the static compiled model
met.clear_all()
res_xla = static_compiled_linear_xla(input_xla_2)
# torch.compile should be called
xm.wait_device_ops()
self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # run the static compiled model with a different batch size; we expect
    # Dynamo to retrace the model.
met.clear_all()
batch = 12
input_xla_2 = torch.randn(batch, in_dim).to(device)
res_xla = static_compiled_linear_xla(input_xla_2)
# torch.compile should be called
xm.wait_device_ops()
self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

def test_dynamic_shape_resnet18(self):
device = torch_xla.device()

sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
loader = self._get_loader(device, sample_count, batch_size=4)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
device_resnet18 = torchvision.models.resnet18()
device_resnet18.load_state_dict(resnet18.state_dict())
device_resnet18.to(device)
device_resnet18.eval()
    # materialize the fake data for test purposes
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
dynamo_resnet18 = torch.compile(
device_resnet18, backend='openxla', dynamic=True)
for data, _ in loader:
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
# TPU has some precision issues, skipping allclose check
if not _is_on_tpu():
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))

previous_extract_compile_count = met.counter_value(
'DynamoExtractCompiledGraph')

loader_new_shape = self._get_loader(device, sample_count, batch_size=2)
for data, _ in loader_new_shape:
output_new_shape = dynamo_resnet18(data)
output_cpu_new_shape = resnet18(data.cpu())
# TPU has some precision issues, skipping allclose check
if not _is_on_tpu():
self.assertTrue(
torch.allclose(
output_cpu_new_shape,
output_new_shape.cpu(),
rtol=1e-05,
atol=1e-05))

self.assertEqual(
met.counter_value('DynamoExtractCompiledGraph'),
previous_extract_compile_count)


if __name__ == '__main__':
  test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -188,6 +188,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
   run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
   run_test "$CDIR/dynamo/test_dynamo.py"
+  run_test "$CDIR/dynamo/test_dynamo_dynamic_shape.py"
   run_test "$CDIR/dynamo/test_bridge.py"
   run_test "$CDIR/dynamo/test_num_output.py"
   run_test "$CDIR/dynamo/test_graph_input_matcher.py"
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -15,6 +15,7 @@ XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models
 XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shapes.py -v
 python3 test/test_autocast.py
 python3 test/dynamo/test_dynamo.py
+python3 test/dynamo/test_dynamo_dynamic_shape.py
 python3 test/spmd/test_spmd_debugging.py
 python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
