diff --git a/test/test_persistent_cache.py b/test/test_persistent_cache.py index 56a64fb2d325..f2da1ed55e53 100644 --- a/test/test_persistent_cache.py +++ b/test/test_persistent_cache.py @@ -20,25 +20,59 @@ def run(*args, **kwargs): return run -# Basic command to generate a simple graph and perform metrics assertions -METRICS_CMD_FMT = r''' +# Command to generate a simple graph and perform correctness/metrics assertions. +# There are four methods supported by configuring the environment for the test: +# - Single-device (default): The program runs a matmul on a single XLA device. +# - Unsharded SPMD: Setting TEST_WITH_SPMD will run the matmul replicated +# across all devices. +# - Sharded SPMD: Setting TEST_WITH_SPMD and MARK_SHARDING will shard the +# computation across all devices. +# - Multiprocess: Setting TEST_WITH_MP will run the same computation across all +# devices using multiprocessing. +TEST_FMT = r''' import os import torch import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met +import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr + +def test_fn(rank=None): + if rank is not None: + # In a multiprocess setting, rank will be set to the process rank. For MP, + # we need to change the cache dir for each process to avoid a race condition + # where one process loads the compilation result of another, which would + # break the metrics assertion. + os.environ['XLA_PERSISTENT_CACHE_PATH'] = \ + os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank)) + + t = torch.randn(16) + expected = t + t + + xt = t.to(xm.xla_device()) + if {'TEST_WITH_SPMD', 'MARK_SHARDING'} <= os.environ.keys(): + n_dev = xr.global_runtime_device_count() + mesh = xs.Mesh(range(n_dev), (n_dev,)) + xs.mark_sharding(xt, mesh, (0,)) + + s = xt + xt + xm.mark_step() + assert torch.allclose(s.cpu(), expected), \ + f'Incorrect result! expected {expected}, got {s.cpu()}' + + for counter, value in %s: + actual = met.counter_value(counter) + assert actual == value, \ + f'Unexpected value for counter {counter}: expected {value}, got {actual}' -if 'TEST_WITH_SPMD' in os.environ: - xr.use_spmd() - -t = torch.randn(4, 4).to(xm.xla_device()) -s = t @ t -xm.mark_step() - -for counter, value in %s: - actual = met.counter_value(counter) - assert actual == value, \ - f'Unexpected value for counter {counter}: expected {value}, got {actual}' +if __name__ == '__main__': + if 'TEST_WITH_MP' in os.environ: + import torch_xla.distributed.xla_multiprocessing as xmp + xmp.spawn(test_fn, start_method='fork') + else: + if 'TEST_WITH_SPMD' in os.environ: + xr.use_spmd() + test_fn() ''' @@ -52,17 +86,15 @@ class PersistentCacheTest(unittest.TestCase): """ def _run_with_metric_assertions(self, env: dict, metric_expectations: dict): - cmd = METRICS_CMD_FMT % list(metric_expectations.items()) + cmd = TEST_FMT % list(metric_expectations.items()) proc = run([sys.executable, '-c', cmd], env=env, stdout=PIPE, stderr=STDOUT) self.assertEqual(proc.returncode, 0, f'Non-zero exit code, output:\n{proc.stdout.decode()}') - def _run_tests(self, tmpdir, use_spmd=False): - cache_dir = os.path.join(tmpdir, 'cache') + @run_with_tmpdir + def test_persistent_cache(self, tmpdir): env = os.environ.copy() - env['XLA_PERSISTENT_CACHE_PATH'] = cache_dir - if use_spmd: - env['TEST_WITH_SPMD'] = '1' + env['XLA_PERSISTENT_CACHE_PATH'] = tmpdir # Use subtests to avoid having to prime the cache for each test. with self.subTest('The first attempt should miss on the persistent cache'): @@ -91,9 +123,17 @@ def _run_tests(self, tmpdir, use_spmd=False): 'PersistentCacheHit': None }) + if xr.device_type() == 'TPU': + with self.subTest('SPMD should result in a different hash'): + env['TEST_WITH_SPMD'] = '1' + self._run_with_metric_assertions(env, { + 'PersistentCacheMiss': 1, + 'PersistentCacheHit': None + }) + with self.subTest('Corrupt serialization should not be loaded'): - for fname in os.listdir(cache_dir): - with open(os.path.join(cache_dir, fname), 'wb') as f: + for fname in os.listdir(tmpdir): + with open(os.path.join(tmpdir, fname), 'wb') as f: f.write(b'') self._run_with_metric_assertions( env, { @@ -102,14 +142,43 @@ def _run_tests(self, tmpdir, use_spmd=False): 'PersistentCacheDeserializeFailure': 1 }) - @run_with_tmpdir - def test_persistent_cache(self, tmpdir): - self._run_tests(tmpdir) - @unittest.skipUnless(xr.device_type() == 'TPU', 'TPU required for SPMD') @run_with_tmpdir def test_persistent_cache_spmd(self, tmpdir): - self._run_tests(tmpdir, use_spmd=True) + env = os.environ.copy() + env.update({ + 'XLA_PERSISTENT_CACHE_PATH': tmpdir, + 'TEST_WITH_SPMD': '1', + 'MARK_SHARDING': '1', + }) + with self.subTest('Warm the cache'): + self._run_with_metric_assertions(env, { + 'PersistentCacheMiss': 1, + 'PersistentCacheHit': None, + }) + with self.subTest('Sharded computation should yield correct result'): + self._run_with_metric_assertions(env, { + 'PersistentCacheMiss': None, + 'PersistentCacheHit': 1, + }) + + @run_with_tmpdir + def test_persistent_cache_mp(self, tmpdir): + env = os.environ.copy() + env.update({ + 'XLA_PERSISTENT_CACHE_PATH': tmpdir, + 'TEST_WITH_MP': '1', + }) + with self.subTest('Warm the cache'): + self._run_with_metric_assertions(env, { + 'PersistentCacheMiss': 1, + 'PersistentCacheHit': None, + }) + with self.subTest('MP computation should yield correct result after load'): + self._run_with_metric_assertions(env, { + 'PersistentCacheMiss': None, + 'PersistentCacheHit': 1, + }) if __name__ == '__main__': diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 6b619851833f..b36c0250af2d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -625,7 +625,8 @@ ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation( "variable to disable persistent caching."; xla::XlaComputation computation((*hlo_modules)[0]->ToProto()); - std::vector devices = {GetDefaultDevice()}; + std::vector devices = {UseVirtualDevice() ? spmd_device_str + : GetDefaultDevice()}; return std::make_shared(std::move(computation), devices, std::move(executable)); } diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 46eef34bb152..5a8d72d04990 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -80,7 +80,7 @@ XLAGraphExecutor::ComputationCache* CreateComputationCache() { static const size_t kMaxCacheSize = runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024); static const bool readonlyPersistentCache = - runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_RO", false); + runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_READ_ONLY", false); static std::string persistentCacheDir = runtime::sys_util::GetEnvString("XLA_PERSISTENT_CACHE_PATH", ""); if (!persistentCacheDir.empty()) { @@ -94,10 +94,8 @@ XLAGraphExecutor::ComputationCache* CreateComputationCache() { runtime::GetComputationClient()->DeserializeComputation( serialization); if (!computation) return nullptr; - bool is_sharded = bridge::GetDefaultDevice()->toString() == - GetVirtualDevice().toString(); return std::make_shared( - computation, /*is_sharded=*/is_sharded); + computation, /*is_sharded=*/UseVirtualDevice()); }; return new XLAGraphExecutor::PersistentCache( kMaxCacheSize, persistentCacheDir, readonlyPersistentCache,