forked from pytorch/xla
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for dynamic shape in dynamo (pytorch#7676)
Co-authored-by: JackCaoG <jackcao@google.com>
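For context, the tests added below exercise this feature through torch.compile with the openxla backend and dynamic=True. A minimal sketch of that usage pattern, using an illustrative Linear model and shapes rather than anything prescribed by the commit itself:

import torch
import torch_xla

device = torch_xla.device()
model = torch.nn.Linear(10, 20).to(device)
# dynamic=True asks Dynamo to trace the batch dimension symbolically instead of
# specializing on the first input shape it sees, so later calls with a new batch
# size should not force a Dynamo retrace (XLA may still recompile the HLO).
compiled = torch.compile(model, backend="openxla", dynamic=True)
out = compiled(torch.randn(8, 10).to(device))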
Showing 5 changed files with 313 additions and 9 deletions.
@@ -0,0 +1,234 @@
import unittest
import sys

import torch
import torch_xla
import torchvision

from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.utils.utils as xu


def _is_on_tpu():
  return xr.device_type() == 'TPU'


class DynamoDynamicShapeBasicTest(unittest.TestCase):

  def _get_loader(self, device, sample_count, batch_size=4):
    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
    loader = xu.SampleGenerator(
        data=(torch.randn(batch_size, 3, 224, 224, device=device),
              torch.zeros(batch_size, dtype=torch.int64, device=device)),
        sample_count=sample_count)
    return loader

  def _get_linear_and_input(self, in_dim: int, out_dim: int, batch_dim: int,
                            device: torch.device):
    dummy_linear = torch.nn.Linear(in_dim, out_dim)
    dummy_linear_xla = torch.nn.Linear(in_dim, out_dim).to(device)
    dummy_linear_xla.load_state_dict(dummy_linear.state_dict())
    input = torch.randn(batch_dim, in_dim)
    input_xla = input.to(device)
    return (dummy_linear, dummy_linear_xla, input, input_xla)

  def test_dynamic_shape_basic(self):
    torch_xla.manual_seed(100)
    device = torch_xla.device()
    # model setup
    dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input(
        10, 20, 20, device)
    compiled_linear_xla = torch.compile(
        dummy_linear_xla, backend="openxla", dynamic=True)
    xm.wait_device_ops()
    met.clear_all()

    # first run
    res = dummy_linear(input)
    res_xla = compiled_linear_xla(input_xla)
    # TPU matmul happens in bf16
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
    # torch.compile should be called once
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # second run with different input shape
    input = torch.randn(30, 10)
    input_xla = input.to(device)
    met.clear_all()
    res = dummy_linear(input)
    res_xla = compiled_linear_xla(input_xla)
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
    # torch.compile should not retrace but xla will recompile
    self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

  def test_dynamic_shape_multiple_batches(self):
    torch_xla.manual_seed(100)
    device = torch_xla.device()
    # model setup
    in_dim = 16
    out_dim = 32
    batch = 8
    dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input(
        in_dim, out_dim, batch, device)
    compiled_linear_xla = torch.compile(
        dummy_linear_xla, backend="openxla", dynamic=True)
    xm.wait_device_ops()
    met.clear_all()

    # first run with batch 8
    res_xla = compiled_linear_xla(input_xla)
    # torch.compile should be called once
    xm.wait_device_ops()
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # then run with batch 16
    met.clear_all()
    batch = 16
    input_xla = torch.randn(batch, in_dim).to(device)
    res_xla = compiled_linear_xla(input_xla)
    # torch.compile should not retrace but xla will recompile
    xm.wait_device_ops()
    self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # then run with batch 32
    met.clear_all()
    batch = 32
    input_xla = torch.randn(batch, in_dim).to(device)
    res_xla = compiled_linear_xla(input_xla)
    # torch.compile should not retrace but xla will recompile
    xm.wait_device_ops()
    self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # run with batch 8 again and make sure we don't recompile the HLO
    met.clear_all()
    batch = 8
    input = torch.randn(batch, in_dim)
    input_xla = input.to(device)
    res_xla = compiled_linear_xla(input_xla)
    res = dummy_linear(input)
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
    # torch.compile should not retrace, xla also will not compile
    self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
    self.assertNotIn('CompileTime', met.metric_names())
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

  def test_dynamic_shape_mix_with_non_dynamic(self):
    torch_xla.manual_seed(100)
    device = torch_xla.device()
    # model setup
    in_dim = 15
    out_dim = 31
    out_dim_2 = 33
    batch = 8
    _, dummy_linear_xla, _, input_xla = self._get_linear_and_input(
        in_dim, out_dim, batch, device)
    dynamic_compiled_linear_xla = torch.compile(
        dummy_linear_xla, backend="openxla", dynamic=True)
    _, dummy_linear_xla_2, _, input_xla_2 = self._get_linear_and_input(
        in_dim, out_dim_2, batch, device)
    static_compiled_linear_xla = torch.compile(
        dummy_linear_xla_2, backend="openxla")
    xm.wait_device_ops()
    met.clear_all()

    # first run the dynamic compiled model
    res_xla = dynamic_compiled_linear_xla(input_xla)
    # torch.compile should be called once
    xm.wait_device_ops()
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # and then run the dynamic compiled model with a different batch size
    met.clear_all()
    batch = 32
    input_xla = torch.randn(batch, in_dim).to(device)
    res_xla = dynamic_compiled_linear_xla(input_xla)
    # torch.compile should not retrace but xla will recompile
    xm.wait_device_ops()
    self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # now run the static compiled model
    met.clear_all()
    res_xla = static_compiled_linear_xla(input_xla_2)
    # torch.compile should be called
    xm.wait_device_ops()
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # run the static compiled model with a different batch size; we expect
    # dynamo to retrace the model.
    met.clear_all()
    batch = 12
    input_xla_2 = torch.randn(batch, in_dim).to(device)
    res_xla = static_compiled_linear_xla(input_xla_2)
    # torch.compile should be called
    xm.wait_device_ops()
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

  def test_dynamic_shape_resnet18(self):
    device = torch_xla.device()

    sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
    loader = self._get_loader(device, sample_count, batch_size=4)
    resnet18 = torchvision.models.resnet18()
    resnet18.eval()
    device_resnet18 = torchvision.models.resnet18()
    device_resnet18.load_state_dict(resnet18.state_dict())
    device_resnet18.to(device)
    device_resnet18.eval()
    # materialize the fake data for test purposes
    xm.mark_step()
    xm.wait_device_ops()
    met.clear_all()
    dynamo_resnet18 = torch.compile(
        device_resnet18, backend='openxla', dynamic=True)
    for data, _ in loader:
      output = dynamo_resnet18(data)
      output_cpu = resnet18(data.cpu())
      # TPU has some precision issues, skipping the allclose check
      if not _is_on_tpu():
        self.assertTrue(
            torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))

    previous_extract_compile_count = met.counter_value(
        'DynamoExtractCompiledGraph')

    loader_new_shape = self._get_loader(device, sample_count, batch_size=2)
    for data, _ in loader_new_shape:
      output_new_shape = dynamo_resnet18(data)
      output_cpu_new_shape = resnet18(data.cpu())
      # TPU has some precision issues, skipping the allclose check
      if not _is_on_tpu():
        self.assertTrue(
            torch.allclose(
                output_cpu_new_shape,
                output_new_shape.cpu(),
                rtol=1e-05,
                atol=1e-05))

    self.assertEqual(
        met.counter_value('DynamoExtractCompiledGraph'),
        previous_extract_compile_count)


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
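As a brief follow-up on the verification pattern above: the tests distinguish a Dynamo retrace from an XLA-only recompile via torch_xla.debug.metrics. A hedged, self-contained sketch of that check, reusing only the counters and metrics referenced in the test file (model and shapes are again illustrative):

import torch
import torch_xla
import torch_xla.debug.metrics as met

device = torch_xla.device()
compiled = torch.compile(
    torch.nn.Linear(10, 20).to(device), backend="openxla", dynamic=True)
compiled(torch.randn(8, 10).to(device))  # first call: Dynamo trace + XLA compile

met.clear_all()
compiled(torch.randn(16, 10).to(device))  # new batch size
# No Dynamo retrace expected, so this counter should not reappear after
# clear_all(); XLA itself may still record a CompileTime sample for the shape.
assert 'DynamoExtractCompiledGraph' not in met.counter_names()
print(met.metric_data('CompileTime'), met.metric_data('ExecuteTime'))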