Commit f009baa: Improve tests and address comments

jonb377 committed Dec 6, 2023
1 parent 7c3ff58 commit f009baa
Showing 5 changed files with 85 additions and 152 deletions.
213 changes: 71 additions & 142 deletions test/test_persistent_cache.py
@@ -1,10 +1,15 @@
+from absl.testing import absltest, parameterized
+from concurrent.futures import ProcessPoolExecutor
 import functools
 import os
 from subprocess import run, STDOUT, PIPE
 import sys
 import tempfile
 import unittest

 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+import torch_xla.distributed.spmd as xs
+import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.runtime as xr

@@ -20,168 +25,92 @@ def run(*args, **kwargs):
   return run


-# Command to generate a simple graph and perform correctness/metrics assertions.
-# There are four methods supported by configuring the environment for the test:
-# - Single-device (default): The program runs a computation on a single XLA
-#   device.
-# - Unsharded SPMD: Setting TEST_WITH_SPMD will run the computation replicated
-#   across all devices.
-# - Sharded SPMD: Setting TEST_WITH_SPMD and MARK_SHARDING will shard the
-#   computation across all devices.
-# - Multiprocess: Setting TEST_WITH_MP will run the same computation across all
-#   devices using multiprocessing.
-TEST_FMT = r'''
-import os
-import torch
-import torch_xla.core.xla_model as xm
-import torch_xla.debug.metrics as met
-import torch_xla.distributed.spmd as xs
-import torch_xla.runtime as xr
-def test_fn(rank=None):
-  if rank is not None:
-    # In a multiprocess setting, rank will be set to the process rank. For MP,
-    # we need to change the cache dir for each process to avoid a race condition
-    # where one process loads the compilation result of another, which would
-    # break the metrics assertion.
-    os.environ['XLA_PERSISTENT_CACHE_PATH'] = \
-        os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank))
+def _test_spawn(fn, args):
+  # Use a new ProcessPoolExecutor for each test to release device locks.
+  with ProcessPoolExecutor() as pool:
+    pool.submit(fn, *args).result()

-  t = torch.randn(16)
-  expected = t + t
-  xt = t.to(xm.xla_device())
-  if {'TEST_WITH_SPMD', 'MARK_SHARDING'} <= os.environ.keys():
-    n_dev = xr.global_runtime_device_count()
-    mesh = xs.Mesh(range(n_dev), (n_dev,))
-    xs.mark_sharding(xt, mesh, (0,))

+def _assert_correctness_and_metrics(t, xt, metrics):
+  expected = t + t
   s = xt + xt
   xm.mark_step()
   assert torch.allclose(s.cpu(), expected), \
       f'Incorrect result! expected {expected}, got {s.cpu()}'
-  for counter, value in %s:
+  for counter, value in metrics.items():
     actual = met.counter_value(counter)
     assert actual == value, \
         f'Unexpected value for counter {counter}: expected {value}, got {actual}'

-if __name__ == '__main__':
-  if 'TEST_WITH_MP' in os.environ:
-    import torch_xla.distributed.xla_multiprocessing as xmp
-    xmp.spawn(test_fn, start_method='fork')
-  else:
-    if 'TEST_WITH_SPMD' in os.environ:
-      xr.use_spmd()
-    test_fn()
-'''
+def _mp_test(rank, metrics):
+  # In MP, the cache dir must be different for each process to avoid a race
+  # condition where one process loads the compilation result of another, which
+  # would break the metrics assertion.
+  os.environ['XLA_PERSISTENT_CACHE_PATH'] = \
+      os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank))
+
+  t = torch.randn(16)
+  xt = t.to(xm.xla_device())
+  _assert_correctness_and_metrics(t, xt, metrics)
+
+
+def _single_device_test(metrics):
+  t = torch.randn(16)
+  xt = t.to(xm.xla_device())
+  _assert_correctness_and_metrics(t, xt, metrics)
+
+
+def _spmd_replicated_test(metrics):
+  xr.use_spmd()
+  t = torch.randn(16)
+  xt = t.to(xm.xla_device())
+  _assert_correctness_and_metrics(t, xt, metrics)


-@unittest.skipUnless(xr.device_type() in {'TPU', 'GPU'},
+def _spmd_sharded_test(metrics):
+  xr.use_spmd()
+  t = torch.randn(16)
+
+  xt = t.to(xm.xla_device())
+  n_dev = xr.global_runtime_device_count()
+  mesh = xs.Mesh(range(n_dev), (n_dev,))
+  xs.mark_sharding(xt, mesh, (0,))
+  _assert_correctness_and_metrics(t, xt, metrics)
+
+
+@absltest.skipUnless(xr.device_type() in {'TPU', 'GPU'},
                      'Device type does not support persistent caching')
-class PersistentCacheTest(unittest.TestCase):
+class PersistentCacheTest(parameterized.TestCase):
   """
   Test suite to verify compilation cache across processes. Tests will run
   multiple Python subprocesses which use the XLA runtime to populate the cache
   and perform assertions on the metrics generated.
   """

-  def _run_with_metric_assertions(self, env: dict, metric_expectations: dict):
-    cmd = TEST_FMT % list(metric_expectations.items())
-    proc = run([sys.executable, '-c', cmd], env=env, stdout=PIPE, stderr=STDOUT)
-    self.assertEqual(proc.returncode, 0,
-                     f'Non-zero exit code, output:\n{proc.stdout.decode()}')
-
-  @run_with_tmpdir
-  def test_persistent_cache(self, tmpdir):
-    env = os.environ.copy()
-    env['XLA_PERSISTENT_CACHE_PATH'] = tmpdir
-
-    # Use subtests to avoid having to prime the cache for each test.
-    with self.subTest('The first attempt should miss on the persistent cache'):
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': 1,
-          'PersistentCacheHit': None
-      })
-
-    with self.subTest('A second run should hit the cache'):
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': None,
-          'PersistentCacheHit': 1
-      })
-
-    with self.subTest('Ignored XLA flags should not impact the hash'):
-      env['XLA_FLAGS'] = f'--xla_dump_disable_metadata'
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': None,
-          'PersistentCacheHit': 1
-      })
-
-    with self.subTest('Non-ignored LIBTPU_INIT_ARGS should impact the hash'):
-      env['LIBTPU_INIT_ARGS'] = '--xla_enable_async_collective_permute=true'
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': 1,
-          'PersistentCacheHit': None
-      })
-
-    if xr.device_type() == 'TPU':
-      with self.subTest('SPMD should result in a different hash'):
-        env['TEST_WITH_SPMD'] = '1'
-        self._run_with_metric_assertions(env, {
-            'PersistentCacheMiss': 1,
-            'PersistentCacheHit': None
-        })
-
-    with self.subTest('Corrupt serialization should not be loaded'):
-      for fname in os.listdir(tmpdir):
-        with open(os.path.join(tmpdir, fname), 'wb') as f:
-          f.write(b'')
-      self._run_with_metric_assertions(
-          env, {
-              'PersistentCacheMiss': None,
-              'PersistentCacheHit': None,
-              'PersistentCacheDeserializeFailure': 1
-          })
-
-  @unittest.skipUnless(xr.device_type() == 'TPU', 'TPU required for SPMD')
+  @parameterized.named_parameters(
+      ('mp', xmp.spawn, _mp_test),
+      ('single_device', _test_spawn, _single_device_test),
+      ('spmd_replicated', _test_spawn, _spmd_replicated_test),
+      ('spmd_sharded', _test_spawn, _spmd_sharded_test),
+  )
   @run_with_tmpdir
-  def test_persistent_cache_spmd(self, tmpdir):
-    env = os.environ.copy()
-    env.update({
-        'XLA_PERSISTENT_CACHE_PATH': tmpdir,
-        'TEST_WITH_SPMD': '1',
-        'MARK_SHARDING': '1',
-    })
-    with self.subTest('Warm the cache'):
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': 1,
-          'PersistentCacheHit': None,
-      })
-    with self.subTest('Sharded computation should yield correct result'):
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': None,
-          'PersistentCacheHit': 1,
-      })
+  def test_persistent_cache(self, launch_method, test_fn, tmpdir):
+    os.environ['XLA_PERSISTENT_CACHE_PATH'] = tmpdir

-  @run_with_tmpdir
-  def test_persistent_cache_mp(self, tmpdir):
-    env = os.environ.copy()
-    env.update({
-        'XLA_PERSISTENT_CACHE_PATH': tmpdir,
-        'TEST_WITH_MP': '1',
-    })
-    with self.subTest('Warm the cache'):
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': 1,
-          'PersistentCacheHit': None,
-      })
-    with self.subTest('MP computation should yield correct result after load'):
-      self._run_with_metric_assertions(env, {
-          'PersistentCacheMiss': None,
-          'PersistentCacheHit': 1,
-      })
+    # Run once to warm the cache
+    launch_method(test_fn, ({
+        'PersistentCacheMiss': 1,
+        'PersistentCacheHit': None
+    },))
+
+    # The second run should hit the cache
+    launch_method(test_fn, ({
+        'PersistentCacheMiss': None,
+        'PersistentCacheHit': 1
+    },))


 if __name__ == '__main__':
-  test = unittest.main()
+  test = absltest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
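
Note: as a quick illustration of the behavior under test, below is a minimal standalone sketch of how the persistent cache is exercised from user code. The XLA_PERSISTENT_CACHE_PATH variable and the counter names come from the diff above; the script structure and tensor shape are illustrative assumptions.

import os
import tempfile

# The cache directory should be set before the first compilation is triggered.
os.environ['XLA_PERSISTENT_CACHE_PATH'] = os.path.join(
    tempfile.gettempdir(), 'xla_cache')

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

t = torch.randn(16).to(xm.xla_device())
s = t + t
xm.mark_step()  # Triggers compilation of the captured graph.

# On a cold cache this records PersistentCacheMiss == 1; re-running the same
# script should instead record PersistentCacheHit == 1.
print('misses:', met.counter_value('PersistentCacheMiss'))
print('hits:', met.counter_value('PersistentCacheHit'))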
6 changes: 4 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
@@ -282,11 +282,13 @@ class ComputationClient {
       std::vector<CompileInstance> instances) = 0;

   // Serialize a computation to a string.
-  virtual std::string SerializeComputation(ComputationPtr computation) = 0;
+  virtual std::string SerializeComputation(
+      const ComputationPtr computation) = 0;

   // Deserialize a string resulting from SerializeComputation back to a
   // Computation. If the deserialization fails, nullptr is returned.
-  virtual ComputationPtr DeserializeComputation(std::string& serialized) = 0;
+  virtual ComputationPtr DeserializeComputation(
+      const std::string& serialized) = 0;

   // Returns a hash of the current compilation environment.
   virtual torch::lazy::hash_t HashCompilationEnv() = 0;
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -599,15 +599,15 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
 }

 std::string PjRtComputationClient::SerializeComputation(
-    ComputationPtr computation) {
+    const ComputationPtr computation) {
   const PjRtComputation& pjrt_computation =
       dynamic_cast<const PjRtComputation&>(*computation);

   return ConsumeValue(pjrt_computation.executable->SerializeExecutable());
 }

 ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation(
-    std::string& serialized) {
+    const std::string& serialized) {
   auto executable_or = client_->DeserializeExecutable(serialized, std::nullopt);
   if (!executable_or.ok()) {
     TF_LOG(WARNING) << "Failed to deserialize executable: "
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -54,9 +54,9 @@ class PjRtComputationClient : public ComputationClient {
   std::vector<ComputationPtr> Compile(
       std::vector<CompileInstance> instances) override;

-  std::string SerializeComputation(ComputationPtr computation) override;
+  std::string SerializeComputation(const ComputationPtr computation) override;

-  ComputationPtr DeserializeComputation(std::string& serialized) override;
+  ComputationPtr DeserializeComputation(const std::string& serialized) override;

   std::vector<DataPtr> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
10 changes: 6 additions & 4 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -84,13 +84,15 @@ XLAGraphExecutor::ComputationCache* CreateComputationCache() {
   static std::string persistentCacheDir =
       runtime::sys_util::GetEnvString("XLA_PERSISTENT_CACHE_PATH", "");
   if (!persistentCacheDir.empty()) {
-    auto serialize_fn = [](auto& computation) -> std::string {
+    auto serialize_fn =
+        [](XLAGraphExecutor::ComputationCache::TypePtr computation)
+        -> std::string {
       return runtime::GetComputationClient()->SerializeComputation(
           computation->computation);
     };
-    auto deserialize_fn = [](auto& serialization)
-        -> std::shared_ptr<XLAGraphExecutor::CachedComputation> {
-      auto computation =
+    auto deserialize_fn = [](std::string serialization)
+        -> XLAGraphExecutor::ComputationCache::TypePtr {
+      runtime::ComputationClient::ComputationPtr computation =
           runtime::GetComputationClient()->DeserializeComputation(
               serialization);
       if (!computation) return nullptr;
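
Note: the serialize_fn/deserialize_fn lambdas above are what connect the in-memory computation cache to disk. Conceptually (this is a hypothetical Python sketch, not the actual C++ implementation; the function and parameter names are invented for illustration), the resulting load-or-compile flow is:

import os

def load_or_compile(cache_dir, graph_hash, compile_fn, serialize_fn,
                    deserialize_fn):
  # Hypothetical mirror of the persistent-cache flow wired up above.
  path = os.path.join(cache_dir, graph_hash)
  if os.path.exists(path):
    with open(path, 'rb') as f:
      computation = deserialize_fn(f.read())
    if computation is not None:  # PersistentCacheHit
      return computation
    # A corrupt or incompatible entry deserializes to None
    # (PersistentCacheDeserializeFailure); fall through and recompile.
  computation = compile_fn()  # PersistentCacheMiss
  with open(path, 'wb') as f:
    f.write(serialize_fn(computation))
  return computation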