From 96c48c94ebcc55b09aab8e17a1df22f86abfced8 Mon Sep 17 00:00:00 2001
From: jonb377
Date: Wed, 6 Dec 2023 16:11:26 -0800
Subject: [PATCH] Enable persistent compilation caching (#5804)

---
 test/run_tests.sh                             |   1 +
 test/test_persistent_cache.py                 | 124 ++++++++++++++++++
 torch_xla/csrc/runtime/computation_client.h   |   9 ++
 .../csrc/runtime/pjrt_computation_client.cc   |  36 +++++
 .../csrc/runtime/pjrt_computation_client.h    |   4 +
 torch_xla/csrc/xla_graph_executor.cpp         |  34 ++++-
 torch_xla/csrc/xla_graph_executor.h           |   6 +
 7 files changed, 211 insertions(+), 3 deletions(-)
 create mode 100644 test/test_persistent_cache.py

diff --git a/test/run_tests.sh b/test/run_tests.sh
index c3fd72572592..b5e33436e23b 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -210,6 +210,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
   run_torchrun "$CDIR/pjrt/test_torchrun.py"
+  run_test "$CDIR/test_persistent_cache.py"
   # NOTE: this line below is testing export and don't care about GPU
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
 }
diff --git a/test/test_persistent_cache.py b/test/test_persistent_cache.py
new file mode 100644
index 000000000000..40aef453be39
--- /dev/null
+++ b/test/test_persistent_cache.py
@@ -0,0 +1,124 @@
+from absl.testing import absltest, parameterized
+from concurrent.futures import ProcessPoolExecutor
+import functools
+import os
+import sys
+import tempfile
+
+import torch
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+import torch_xla.distributed.spmd as xs
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr
+
+
+# Wrapper to manage a temporary directory for the wrapped test
+def run_with_tmpdir(f):
+
+  @functools.wraps(f)
+  def run(*args, **kwargs):
+    with tempfile.TemporaryDirectory() as tmpdir:
+      kwargs.setdefault('tmpdir', tmpdir)
+      f(*args, **kwargs)
+
+  return run
+
+
+def _test_spawn(fn, args):
+  # Use a new ProcessPoolExecutor for each test to release device locks.
+  with ProcessPoolExecutor() as pool:
+    pool.submit(fn, *args).result()
+
+
+def _assert_correctness_and_metrics(t, xt, metrics):
+  expected = t + t
+  s = xt + xt
+  xm.mark_step()
+  assert torch.allclose(s.cpu(), expected), \
+      f'Incorrect result! expected {expected}, got {s.cpu()}'
+  for counter, value in metrics.items():
+    actual = met.counter_value(counter)
+    assert actual == value, \
+        f'Unexpected value for counter {counter}: expected {value}, got {actual}'
+
+
+def _mp_test(rank, metrics):
+  # In MP, the cache dir must be different for each process to avoid a race
+  # condition where one process loads the compilation result of another,
+  # which would break the metrics assertion.
+  os.environ['XLA_PERSISTENT_CACHE_PATH'] = \
+      os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank))
+
+  t = torch.randn(16)
+  xt = t.to(xm.xla_device())
+  _assert_correctness_and_metrics(t, xt, metrics)
+
+
+def _single_device_test(metrics):
+  t = torch.randn(16)
+  xt = t.to(xm.xla_device())
+  _assert_correctness_and_metrics(t, xt, metrics)
+
+
+def _spmd_replicated_test(metrics):
+  xr.use_spmd()
+  t = torch.randn(16)
+  xt = t.to(xm.xla_device())
+  _assert_correctness_and_metrics(t, xt, metrics)
+
+
+def _spmd_sharded_test(metrics):
+  xr.use_spmd()
+  t = torch.randn(16)
+
+  xt = t.to(xm.xla_device())
+  n_dev = xr.global_runtime_device_count()
+  mesh = xs.Mesh(range(n_dev), (n_dev,))
+  xs.mark_sharding(xt, mesh, (0,))
+  _assert_correctness_and_metrics(t, xt, metrics)
+
+
+@absltest.skipUnless(xr.device_type() in {'TPU', 'CUDA'},
+                     'Device type does not support persistent caching')
+class PersistentCacheTest(parameterized.TestCase):
+  """
+  Test suite to verify compilation cache across processes. Tests will run
+  multiple Python subprocesses which use the XLA runtime to populate the
+  cache and perform assertions on the metrics generated.
+  """
+
+  @run_with_tmpdir
+  def _run_test(self, launch_method, test_fn, tmpdir):
+    os.environ['XLA_PERSISTENT_CACHE_PATH'] = tmpdir
+
+    # Run once to warm the cache
+    launch_method(test_fn, ({
+        'PersistentCacheMiss': 1,
+        'PersistentCacheHit': None
+    },))
+
+    # The second run should hit the cache
+    launch_method(test_fn, ({
+        'PersistentCacheMiss': None,
+        'PersistentCacheHit': 1
+    },))
+
+  def test_persistent_cache_mp(self):
+    self._run_test(xmp.spawn, _mp_test)
+
+  @parameterized.named_parameters(
+      ('single_device', _single_device_test),
+      ('spmd_replicated', _spmd_replicated_test),
+      ('spmd_sharded', _spmd_sharded_test),
+  )
+  @absltest.skipUnless(
+      xr.device_type() == 'TPU',
+      'TPU required for SPMD; single-device GPU is pending #6023')
+  def test_persistent_cache(self, test_fn):
+    self._run_test(_test_spawn, test_fn)
+
+
+if __name__ == '__main__':
+  test = absltest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
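The test above encodes the entire cache protocol: the first process compiles the graph, records exactly one PersistentCacheMiss, and leaves a serialized executable under XLA_PERSISTENT_CACHE_PATH; a fresh process running the same graph under the same compilation environment then records exactly one PersistentCacheHit and skips compilation. The toy C++ model below mirrors that flow with a file-backed cache and a plain counter map; it is a sketch for intuition only, not torch_xla code, and every helper name in it is invented.

    #include <filesystem>
    #include <fstream>
    #include <iostream>
    #include <map>
    #include <sstream>
    #include <string>

    namespace fs = std::filesystem;

    std::map<std::string, int> counters;  // stands in for torch_xla's metrics

    // Return the "executable" for `graph`, loading it from the cache
    // directory when a prior run has already "compiled" it.
    std::string CompileOrLoad(const fs::path& cache_dir, const std::string& hash,
                              const std::string& graph) {
      const fs::path entry = cache_dir / (hash + ".bin");
      if (fs::exists(entry)) {
        ++counters["PersistentCacheHit"];
        std::stringstream ss;
        ss << std::ifstream(entry).rdbuf();  // "deserialize" the cached bytes
        return ss.str();
      }
      ++counters["PersistentCacheMiss"];
      std::string executable = "compiled(" + graph + ")";  // fake compilation
      std::ofstream(entry) << executable;                  // "serialize" to disk
      return executable;
    }

    int main() {
      const fs::path tmpdir = fs::temp_directory_path() / "xla_cache_demo";
      fs::create_directories(tmpdir);
      // First "process": cold cache, so exactly one miss and no hit.
      CompileOrLoad(tmpdir, "graph_hash", "add(t, t)");
      // Second "process": warm cache, so exactly one hit and no further miss.
      CompileOrLoad(tmpdir, "graph_hash", "add(t, t)");
      std::cout << counters["PersistentCacheMiss"] << " miss, "
                << counters["PersistentCacheHit"] << " hit\n";  // 1 miss, 1 hit
      fs::remove_all(tmpdir);
    }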
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index b97e90ffc2e0..6e87c13b193a 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -281,6 +281,15 @@ class ComputationClient {
   virtual std::vector<ComputationPtr> Compile(
       std::vector<CompileInstance> instances) = 0;
 
+  // Serialize a computation to a string.
+  virtual std::string SerializeComputation(
+      const ComputationPtr computation) = 0;
+
+  // Deserialize a string resulting from SerializeComputation back to a
+  // Computation. If the deserialization fails, nullptr is returned.
+  virtual ComputationPtr DeserializeComputation(
+      const std::string& serialized) = 0;
+
   // Returns a hash of the current compilation environment.
   virtual torch::lazy::hash_t HashCompilationEnv() = 0;
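SerializeComputation and DeserializeComputation give the executor an opaque byte format for compiled programs, while HashCompilationEnv fingerprints everything outside the graph that can affect compilation (compiler version, device topology, flags). A reasonable reading, though the hash plumbing is not part of the hunks shown here, is that the persistent cache key combines the graph hash with this environment hash so an executable serialized under one environment is never deserialized under another. A minimal sketch of that keying, assuming a generic boost-style combiner (torch_xla itself uses the torch::lazy hash utilities) and a hypothetical CacheFileName helper:

    #include <cstdint>
    #include <iostream>
    #include <string>

    // Generic boost-style hash combiner, for illustration only; torch_xla
    // uses the torch::lazy hashing utilities instead.
    uint64_t HashCombine(uint64_t seed, uint64_t v) {
      return seed ^ (v + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
    }

    // Hypothetical helper: the same graph hash under a different environment
    // hash must map to a different file, so a stale executable can never be
    // picked up by mistake.
    std::string CacheFileName(uint64_t graph_hash, uint64_t env_hash) {
      return std::to_string(HashCombine(graph_hash, env_hash)) + ".bin";
    }

    int main() {
      // The same graph compiled under two environments lands in two entries.
      std::cout << CacheFileName(42, /*env_hash=*/1) << "\n";
      std::cout << CacheFileName(42, /*env_hash=*/2) << "\n";
    }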
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index a53ad32a943a..871760d4802a 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -598,6 +598,42 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
   return computations;
 }
 
+std::string PjRtComputationClient::SerializeComputation(
+    const ComputationPtr computation) {
+  const PjRtComputation& pjrt_computation =
+      dynamic_cast<const PjRtComputation&>(*computation);
+
+  return ConsumeValue(pjrt_computation.executable->SerializeExecutable());
+}
+
+ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation(
+    const std::string& serialized) {
+  auto executable_or = client_->DeserializeExecutable(serialized, std::nullopt);
+  if (!executable_or.ok()) {
+    TF_LOG(WARNING) << "Failed to deserialize executable: "
+                    << executable_or.status();
+    return nullptr;
+  }
+  auto executable = std::move(*executable_or);
+
+  auto hlo_modules = executable->GetHloModules();
+  if (!hlo_modules.ok()) {
+    TF_LOG(WARNING)
+        << "Failed to retrieve HLO modules from deserialized executable";
+    return nullptr;
+  }
+  XLA_CHECK(hlo_modules->size() == 1)
+      << "Only a single module is supported for persistent computation "
+         "caching. Please unset the XLA_PERSISTENT_CACHE_PATH "
+         "variable to disable persistent caching.";
+  xla::XlaComputation computation((*hlo_modules)[0]->ToProto());
+
+  std::vector<std::string> devices = {UseVirtualDevice() ? spmd_device_str
+                                                         : GetDefaultDevice()};
+  return std::make_shared<PjRtComputation>(std::move(computation), devices,
+                                           std::move(executable));
+}
+
 torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
   // TODO(jonbolin): Incorporate CompileOptions into the hash. These are
   // deterministically generated at the moment, so they don't need to be
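Note the shape of DeserializeComputation above: a payload that fails to deserialize, or whose HLO modules cannot be recovered, is downgraded to a warning plus nullptr rather than an error, and only the one-module invariant is a hard XLA_CHECK. The caller can therefore treat a bad cache entry as an ordinary miss and recompile. The self-contained sketch below illustrates that caller-side contract; FakeClient and GetExecutable are invented stand-ins, not the torch_xla API.

    #include <iostream>
    #include <memory>
    #include <optional>
    #include <string>

    struct Computation {
      std::string bytes;
    };
    using ComputationPtr = std::shared_ptr<Computation>;

    // Stand-in for the ComputationClient interface; the real deserialization
    // work happens in PjRtComputationClient.
    struct FakeClient {
      ComputationPtr DeserializeComputation(const std::string& serialized) {
        if (serialized.empty()) return nullptr;  // failure -> nullptr, no throw
        return std::make_shared<Computation>(Computation{serialized});
      }
      ComputationPtr Compile() { return std::make_shared<Computation>(); }
    };

    ComputationPtr GetExecutable(FakeClient& client,
                                 std::optional<std::string> cached_bytes) {
      if (cached_bytes) {
        if (ComputationPtr c = client.DeserializeComputation(*cached_bytes)) {
          return c;  // cache hit: compilation is skipped entirely
        }
        // Corrupt or incompatible payload: fall through and recompile
        // instead of failing the step. This is why nullptr, not an
        // exception, is the failure signal.
      }
      return client.Compile();
    }

    int main() {
      FakeClient client;
      std::cout << (GetExecutable(client, "payload") ? "hit" : "?") << "\n";
      std::cout << (GetExecutable(client, "") ? "recompiled" : "?") << "\n";
    }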
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 53b16a79ca58..a54b3c0b3b1f 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -54,6 +54,10 @@ class PjRtComputationClient : public ComputationClient {
   std::vector<ComputationPtr> Compile(
       std::vector<CompileInstance> instances) override;
 
+  std::string SerializeComputation(const ComputationPtr computation) override;
+
+  ComputationPtr DeserializeComputation(
+      const std::string& serialized) override;
+
   std::vector<DataPtr> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       const std::string& device,
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index c19f776ea22f..ef5683da8e54 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -76,6 +76,36 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) {
   return ir_value->op() != xla_not_supported;
 }
 
+XLAGraphExecutor::ComputationCache* CreateComputationCache() {
+  static const size_t kMaxCacheSize =
+      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
+  static const bool readonlyPersistentCache =
+      runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_READ_ONLY", false);
+  static std::string persistentCacheDir =
+      runtime::sys_util::GetEnvString("XLA_PERSISTENT_CACHE_PATH", "");
+  if (!persistentCacheDir.empty()) {
+    auto serialize_fn =
+        [](XLAGraphExecutor::ComputationCache::TypePtr computation)
+        -> std::string {
+      return runtime::GetComputationClient()->SerializeComputation(
+          computation->computation);
+    };
+    auto deserialize_fn = [](std::string serialization)
+        -> XLAGraphExecutor::ComputationCache::TypePtr {
+      runtime::ComputationClient::ComputationPtr computation =
+          runtime::GetComputationClient()->DeserializeComputation(
+              serialization);
+      if (!computation) return nullptr;
+      return std::make_shared<XLAGraphExecutor::CachedComputation>(
+          computation, /*is_sharded=*/UseVirtualDevice());
+    };
+    return new XLAGraphExecutor::PersistentCache(
+        kMaxCacheSize, persistentCacheDir, readonlyPersistentCache,
+        serialize_fn, deserialize_fn);
+  }
+  return new XLAGraphExecutor::MemoryCache(kMaxCacheSize);
+}
+
 }  // namespace
 
 auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* {
@@ -477,9 +507,7 @@ void XLAGraphExecutor::MaybeDumpGraph(std::string name,
 }
 
 XLAGraphExecutor::ComputationCache* XLAGraphExecutor::GetComputationCache() {
-  static const size_t kMaxCacheSize =
-      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
-  static ComputationCache* cache = new ComputationCache(kMaxCacheSize);
+  static ComputationCache* cache = CreateComputationCache();
   return cache;
 }
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 90eec4012d68..e6820173d10d 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -163,8 +163,14 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   };
 
   using ComputationCache =
+      runtime::util::AbstractCache<torch::lazy::hash_t, CachedComputation,
+                                   torch::lazy::HashReducer>;
+  using MemoryCache =
       runtime::util::Cache<torch::lazy::hash_t, CachedComputation,
                            torch::lazy::HashReducer>;
+  using PersistentCache =
+      runtime::util::PersistentCache<torch::lazy::hash_t, CachedComputation,
+                                     torch::lazy::HashReducer>;
 
   ComputationCache* GetComputationCache();
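The executor wiring above is the whole integration surface: CreateComputationCache decides once, at startup, whether GetComputationCache hands back the plain in-memory MemoryCache or a PersistentCache configured by XLA_PERSISTENT_CACHE_PATH and XLA_PERSISTENT_CACHE_READ_ONLY, with serialization injected as two function objects so the cache itself stays backend-agnostic. Below is a minimal sketch of that injected-serializer design, assuming a simplified DiskCache in place of the real runtime::util::PersistentCache (no eviction, locking, or error handling).

    #include <filesystem>
    #include <fstream>
    #include <functional>
    #include <map>
    #include <memory>
    #include <sstream>
    #include <string>

    namespace fs = std::filesystem;

    // Hypothetical stand-in for runtime::util::PersistentCache: a memory map
    // backed by one file per key, with serialization supplied by the caller.
    template <typename K, typename T>
    class DiskCache {
     public:
      using TypePtr = std::shared_ptr<T>;
      using SerializeFn = std::function<std::string(TypePtr)>;
      using DeserializeFn = std::function<TypePtr(std::string)>;

      DiskCache(fs::path dir, bool readonly, SerializeFn ser, DeserializeFn deser)
          : dir_(std::move(dir)),
            readonly_(readonly),
            serialize_(std::move(ser)),
            deserialize_(std::move(deser)) {}

      TypePtr Get(const K& key) {
        if (auto it = memory_.find(key); it != memory_.end()) return it->second;
        const fs::path entry = dir_ / std::to_string(std::hash<K>{}(key));
        if (!fs::exists(entry)) return nullptr;  // true miss: caller compiles
        std::stringstream ss;
        ss << std::ifstream(entry).rdbuf();
        TypePtr value = deserialize_(ss.str());  // nullptr if payload is bad
        if (value) memory_[key] = value;
        return value;
      }

      void Add(const K& key, TypePtr value) {
        memory_[key] = value;
        if (!readonly_) {  // read-only mode never writes back to disk
          const fs::path entry = dir_ / std::to_string(std::hash<K>{}(key));
          std::ofstream(entry) << serialize_(value);
        }
      }

     private:
      fs::path dir_;
      bool readonly_;
      SerializeFn serialize_;
      DeserializeFn deserialize_;
      std::map<K, TypePtr> memory_;
    };

    int main() {
      const fs::path dir = fs::temp_directory_path() / "disk_cache_demo";
      fs::create_directories(dir);
      // The two lambdas mirror serialize_fn/deserialize_fn above.
      DiskCache<std::string, std::string> cache(
          dir, /*readonly=*/false,
          [](std::shared_ptr<std::string> v) { return *v; },
          [](std::string s) { return std::make_shared<std::string>(s); });
      cache.Add("graph_hash", std::make_shared<std::string>("executable-bytes"));
      // A later process constructs the same cache and hits on Get("graph_hash").
      fs::remove_all(dir);
    }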