Add MP tests
jonb377 committed Nov 16, 2023
1 parent 8f3519b commit f124553
Showing 3 changed files with 99 additions and 31 deletions.
121 changes: 95 additions & 26 deletions test/test_persistent_cache.py
@@ -20,25 +20,59 @@ def run(*args, **kwargs):
  return run


# Basic command to generate a simple graph and perform metrics assertions
METRICS_CMD_FMT = r'''
# Command to generate a simple graph and perform correctness/metrics assertions.
# There are four methods supported by configuring the environment for the test:
# - Single-device (default): The program runs a matmul on a single XLA device.
# - Unsharded SPMD: Setting TEST_WITH_SPMD will run the matmul replicated
# across all devices.
# - Sharded SPMD: Setting TEST_WITH_SPMD and MARK_SHARDING will shard the
# computation across all devices.
# - Multiprocess: Setting TEST_WITH_MP will run the same computation across all
# devices using multiprocessing.
TEST_FMT = r'''
import os
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
def test_fn(rank=None):
  if rank is not None:
    # In a multiprocess setting, rank will be set to the process rank. For MP,
    # we need to change the cache dir for each process to avoid a race condition
    # where one process loads the compilation result of another, which would
    # break the metrics assertion.
    os.environ['XLA_PERSISTENT_CACHE_PATH'] = \
        os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank))
  t = torch.randn(16)
  expected = t + t
  xt = t.to(xm.xla_device())
  if {'TEST_WITH_SPMD', 'MARK_SHARDING'} <= os.environ.keys():
    n_dev = xr.global_runtime_device_count()
    mesh = xs.Mesh(range(n_dev), (n_dev,))
    xs.mark_sharding(xt, mesh, (0,))
  s = xt + xt
  xm.mark_step()
  assert torch.allclose(s.cpu(), expected), \
      f'Incorrect result! expected {expected}, got {s.cpu()}'
  for counter, value in %s:
    actual = met.counter_value(counter)
    assert actual == value, \
        f'Unexpected value for counter {counter}: expected {value}, got {actual}'
if 'TEST_WITH_SPMD' in os.environ:
  xr.use_spmd()
t = torch.randn(4, 4).to(xm.xla_device())
s = t @ t
xm.mark_step()
for counter, value in %s:
  actual = met.counter_value(counter)
  assert actual == value, \
      f'Unexpected value for counter {counter}: expected {value}, got {actual}'
if __name__ == '__main__':
  if 'TEST_WITH_MP' in os.environ:
    import torch_xla.distributed.xla_multiprocessing as xmp
    xmp.spawn(test_fn, start_method='fork')
  else:
    if 'TEST_WITH_SPMD' in os.environ:
      xr.use_spmd()
    test_fn()
'''
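
The `%s` placeholder inside TEST_FMT is filled in by the test harness below (`_run_with_metric_assertions`) with a literal list of (counter, expected_value) pairs, so each generated script carries its own metric assertions. A minimal sketch of that substitution follows; it is not part of the commit, and it uses a stand-in template and hypothetical expectation values.

# Sketch only: 'template' stands in for TEST_FMT and the expectations are
# hypothetical; the real values come from each test's metric_expectations dict.
template = r'''
for counter, value in %s:
  print(counter, value)
'''
expectations = {'PersistentCacheMiss': 1, 'PersistentCacheHit': None}
generated = template % list(expectations.items())
# 'generated' now reads:
#   for counter, value in [('PersistentCacheMiss', 1), ('PersistentCacheHit', None)]:
#     print(counter, value)
print(generated)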


@@ -52,17 +86,15 @@ class PersistentCacheTest(unittest.TestCase):
"""

  def _run_with_metric_assertions(self, env: dict, metric_expectations: dict):
    cmd = METRICS_CMD_FMT % list(metric_expectations.items())
    cmd = TEST_FMT % list(metric_expectations.items())
    proc = run([sys.executable, '-c', cmd], env=env, stdout=PIPE, stderr=STDOUT)
    self.assertEqual(proc.returncode, 0,
                     f'Non-zero exit code, output:\n{proc.stdout.decode()}')

  def _run_tests(self, tmpdir, use_spmd=False):
    cache_dir = os.path.join(tmpdir, 'cache')
  @run_with_tmpdir
  def test_persistent_cache(self, tmpdir):
    env = os.environ.copy()
    env['XLA_PERSISTENT_CACHE_PATH'] = cache_dir
    if use_spmd:
      env['TEST_WITH_SPMD'] = '1'
    env['XLA_PERSISTENT_CACHE_PATH'] = tmpdir

    # Use subtests to avoid having to prime the cache for each test.
    with self.subTest('The first attempt should miss on the persistent cache'):
@@ -91,9 +123,17 @@ def _run_tests(self, tmpdir, use_spmd=False):
          'PersistentCacheHit': None
      })

    if xr.device_type() == 'TPU':
      with self.subTest('SPMD should result in a different hash'):
        env['TEST_WITH_SPMD'] = '1'
        self._run_with_metric_assertions(env, {
            'PersistentCacheMiss': 1,
            'PersistentCacheHit': None
        })

    with self.subTest('Corrupt serialization should not be loaded'):
      for fname in os.listdir(cache_dir):
        with open(os.path.join(cache_dir, fname), 'wb') as f:
      for fname in os.listdir(tmpdir):
        with open(os.path.join(tmpdir, fname), 'wb') as f:
          f.write(b'')
      self._run_with_metric_assertions(
          env, {
@@ -102,14 +142,43 @@
              'PersistentCacheDeserializeFailure': 1
          })

  @run_with_tmpdir
  def test_persistent_cache(self, tmpdir):
    self._run_tests(tmpdir)

  @unittest.skipUnless(xr.device_type() == 'TPU', 'TPU required for SPMD')
  @run_with_tmpdir
  def test_persistent_cache_spmd(self, tmpdir):
    self._run_tests(tmpdir, use_spmd=True)
    env = os.environ.copy()
    env.update({
        'XLA_PERSISTENT_CACHE_PATH': tmpdir,
        'TEST_WITH_SPMD': '1',
        'MARK_SHARDING': '1',
    })
    with self.subTest('Warm the cache'):
      self._run_with_metric_assertions(env, {
          'PersistentCacheMiss': 1,
          'PersistentCacheHit': None,
      })
    with self.subTest('Sharded computation should yield correct result'):
      self._run_with_metric_assertions(env, {
          'PersistentCacheMiss': None,
          'PersistentCacheHit': 1,
      })

  @run_with_tmpdir
  def test_persistent_cache_mp(self, tmpdir):
    env = os.environ.copy()
    env.update({
        'XLA_PERSISTENT_CACHE_PATH': tmpdir,
        'TEST_WITH_MP': '1',
    })
    with self.subTest('Warm the cache'):
      self._run_with_metric_assertions(env, {
          'PersistentCacheMiss': 1,
          'PersistentCacheHit': None,
      })
    with self.subTest('MP computation should yield correct result after load'):
      self._run_with_metric_assertions(env, {
          'PersistentCacheMiss': None,
          'PersistentCacheHit': 1,
      })


if __name__ == '__main__':
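The tests above depend on a `run_with_tmpdir` decorator whose body falls outside the hunks shown; only `def run(*args, **kwargs)` and `return run` are visible in the first hunk's context. A hypothetical minimal implementation consistent with that fragment, not necessarily the repository's actual helper, would create a temporary directory and pass it to the wrapped test as `tmpdir`:

import functools
import tempfile


def run_with_tmpdir(f):
  """Hypothetical sketch: run the wrapped test with a fresh temporary directory
  passed as the `tmpdir` keyword argument, cleaning it up afterwards."""

  @functools.wraps(f)
  def run(*args, **kwargs):
    with tempfile.TemporaryDirectory() as tmpdir:
      kwargs.setdefault('tmpdir', tmpdir)
      f(*args, **kwargs)

  return run
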
3 changes: 2 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -625,7 +625,8 @@ ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation(
"variable to disable persistent caching.";
  xla::XlaComputation computation((*hlo_modules)[0]->ToProto());

  std::vector<std::string> devices = {GetDefaultDevice()};
  std::vector<std::string> devices = {UseVirtualDevice() ? spmd_device_str
                                                         : GetDefaultDevice()};
  return std::make_shared<PjRtComputation>(std::move(computation), devices,
                                           std::move(executable));
}
6 changes: 2 additions & 4 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -80,7 +80,7 @@ XLAGraphExecutor::ComputationCache* CreateComputationCache() {
  static const size_t kMaxCacheSize =
      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
  static const bool readonlyPersistentCache =
      runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_RO", false);
      runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_READ_ONLY", false);
  static std::string persistentCacheDir =
      runtime::sys_util::GetEnvString("XLA_PERSISTENT_CACHE_PATH", "");
  if (!persistentCacheDir.empty()) {
@@ -94,10 +94,8 @@ XLAGraphExecutor::ComputationCache* CreateComputationCache() {
          runtime::GetComputationClient()->DeserializeComputation(
              serialization);
      if (!computation) return nullptr;
      bool is_sharded = bridge::GetDefaultDevice()->toString() ==
                        GetVirtualDevice().toString();
      return std::make_shared<XLAGraphExecutor::CachedComputation>(
          computation, /*is_sharded=*/is_sharded);
          computation, /*is_sharded=*/UseVirtualDevice());
    };
    return new XLAGraphExecutor::PersistentCache(
        kMaxCacheSize, persistentCacheDir, readonlyPersistentCache,
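
Taken together, these C++ changes read the persistent-cache configuration from `XLA_PERSISTENT_CACHE_PATH` and the renamed `XLA_PERSISTENT_CACHE_READ_ONLY` flag. A minimal usage sketch follows; it is not part of the commit, the cache directory is arbitrary, and in practice the variables are typically exported in the launching environment rather than set from Python.

import os

# Assumption: set before importing torch_xla so the runtime sees the values.
os.environ['XLA_PERSISTENT_CACHE_PATH'] = '/tmp/xla_compile_cache'
os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '0'  # allow writing new cache entries

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(4, 4).to(xm.xla_device())
s = t @ t
xm.mark_step()
# The first run compiles and serializes the executable under the cache path;
# an identical later run should hit the persistent cache instead of recompiling.
print(s.cpu())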
