Enable persistent compilation caching (#5804)
jonb377 authored and golechwierowicz committed Jan 12, 2024
1 parent 7006c01 commit 96c48c9
Showing 7 changed files with 211 additions and 3 deletions.
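
At a high level, this commit keys the graph executor's compilation cache off new environment variables: when XLA_PERSISTENT_CACHE_PATH is set, compiled executables are serialized to that directory and reloaded on later runs, and XLA_PERSISTENT_CACHE_READ_ONLY can disable writes. The snippet below is a minimal usage sketch, not part of the commit; it assumes a TPU or CUDA PJRT device, that the variables are set before the first graph compilation, and a hypothetical cache directory. The counter names come from the new test file.

# Minimal usage sketch (illustrative, not part of this commit). Assumes
# PJRT_DEVICE is TPU or CUDA and that the env vars are set before the first
# graph compilation, so the persistent cache is selected at startup.
import os

cache_dir = '/tmp/xla_persistent_cache'  # hypothetical cache location
os.environ['XLA_PERSISTENT_CACHE_PATH'] = cache_dir
# Optional: serve cached executables without writing new entries.
# os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '1'

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

t = torch.randn(16, device=xm.xla_device())
s = t + t
xm.mark_step()  # triggers compilation; a cold cache misses, a warm cache hits

print('PersistentCacheMiss:', met.counter_value('PersistentCacheMiss'))
print('PersistentCacheHit:', met.counter_value('PersistentCacheHit'))

Running the script twice against the same directory should show a miss on the first run and a hit on the second, which is the two-phase assertion made by test_persistent_cache.py below.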
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -210,6 +210,7 @@ function run_xla_op_tests3 {
  run_test "$CDIR/test_input_output_aliases.py"
  run_test "$CDIR/test_torch_distributed_xla_backend.py"
  run_torchrun "$CDIR/pjrt/test_torchrun.py"
  run_test "$CDIR/test_persistent_cache.py"
  # NOTE: the line below tests export and doesn't care about GPU
  PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
}
124 changes: 124 additions & 0 deletions test/test_persistent_cache.py
@@ -0,0 +1,124 @@
from absl.testing import absltest, parameterized
from concurrent.futures import ProcessPoolExecutor
import functools
import os
import sys
import tempfile

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


# Wrapper to manage a temporary directory for the wrapped test
def run_with_tmpdir(f):

  @functools.wraps(f)
  def run(*args, **kwargs):
    with tempfile.TemporaryDirectory() as tmpdir:
      kwargs.setdefault('tmpdir', tmpdir)
      f(*args, **kwargs)

  return run


def _test_spawn(fn, args):
  # Use a new ProcessPoolExecutor for each test to release device locks.
  with ProcessPoolExecutor() as pool:
    pool.submit(fn, *args).result()


def _assert_correctness_and_metrics(t, xt, metrics):
  expected = t + t
  s = xt + xt
  xm.mark_step()
  assert torch.allclose(s.cpu(), expected), \
      f'Incorrect result! expected {expected}, got {s.cpu()}'
  for counter, value in metrics.items():
    actual = met.counter_value(counter)
    assert actual == value, \
        f'Unexpected value for counter {counter}: expected {value}, got {actual}'


def _mp_test(rank, metrics):
  # In MP, the cache dir must be different for each process to avoid a race
  # condition where one process loads the compilation result of another, which
  # would break the metrics assertion.
  os.environ['XLA_PERSISTENT_CACHE_PATH'] = \
      os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank))

  t = torch.randn(16)
  xt = t.to(xm.xla_device())
  _assert_correctness_and_metrics(t, xt, metrics)


def _single_device_test(metrics):
  t = torch.randn(16)
  xt = t.to(xm.xla_device())
  _assert_correctness_and_metrics(t, xt, metrics)


def _spmd_replicated_test(metrics):
  xr.use_spmd()
  t = torch.randn(16)
  xt = t.to(xm.xla_device())
  _assert_correctness_and_metrics(t, xt, metrics)


def _spmd_sharded_test(metrics):
  xr.use_spmd()
  t = torch.randn(16)

  xt = t.to(xm.xla_device())
  n_dev = xr.global_runtime_device_count()
  mesh = xs.Mesh(range(n_dev), (n_dev,))
  xs.mark_sharding(xt, mesh, (0,))
  _assert_correctness_and_metrics(t, xt, metrics)


@absltest.skipUnless(xr.device_type() in {'TPU', 'CUDA'},
                     'Device type does not support persistent caching')
class PersistentCacheTest(parameterized.TestCase):
  """
  Test suite to verify compilation cache across processes. Tests will run
  multiple Python subprocesses which use the XLA runtime to populate the cache
  and perform assertions on the metrics generated.
  """

  @run_with_tmpdir
  def _run_test(self, launch_method, test_fn, tmpdir):
    os.environ['XLA_PERSISTENT_CACHE_PATH'] = tmpdir

    # Run once to warm the cache
    launch_method(test_fn, ({
        'PersistentCacheMiss': 1,
        'PersistentCacheHit': None
    },))

    # The second run should hit the cache
    launch_method(test_fn, ({
        'PersistentCacheMiss': None,
        'PersistentCacheHit': 1
    },))

  def test_persistent_cache_mp(self):
    self._run_test(xmp.spawn, _mp_test)

  @parameterized.named_parameters(
      ('single_device', _single_device_test),
      ('spmd_replicated', _spmd_replicated_test),
      ('spmd_sharded', _spmd_sharded_test),
  )
  @absltest.skipUnless(
      xr.device_type() == 'TPU',
      'TPU required for SPMD; single-device GPU is pending #6023')
  def test_persistent_cache(self, test_fn):
    self._run_test(_test_spawn, test_fn)


if __name__ == '__main__':
  test = absltest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
9 changes: 9 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
@@ -281,6 +281,15 @@ class ComputationClient {
  virtual std::vector<ComputationPtr> Compile(
      std::vector<CompileInstance> instances) = 0;

  // Serialize a computation to a string.
  virtual std::string SerializeComputation(
      const ComputationPtr computation) = 0;

  // Deserialize a string resulting from SerializeComputation back to a
  // Computation. If the deserialization fails, nullptr is returned.
  virtual ComputationPtr DeserializeComputation(
      const std::string& serialized) = 0;

  // Returns a hash of the current compilation environment.
  virtual torch::lazy::hash_t HashCompilationEnv() = 0;

36 changes: 36 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -598,6 +598,42 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
  return computations;
}

std::string PjRtComputationClient::SerializeComputation(
    const ComputationPtr computation) {
  const PjRtComputation& pjrt_computation =
      dynamic_cast<const PjRtComputation&>(*computation);

  return ConsumeValue(pjrt_computation.executable->SerializeExecutable());
}

ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation(
    const std::string& serialized) {
  auto executable_or = client_->DeserializeExecutable(serialized, std::nullopt);
  if (!executable_or.ok()) {
    TF_LOG(WARNING) << "Failed to deserialize executable: "
                    << executable_or.status();
    return nullptr;
  }
  auto executable = std::move(*executable_or);

  auto hlo_modules = executable->GetHloModules();
  if (!hlo_modules.ok()) {
    TF_LOG(WARNING)
        << "Failed to retrieve HLO modules from deserialized executable";
    return nullptr;
  }
  XLA_CHECK(hlo_modules->size() == 1)
      << "Only a single module is supported for persistent computation "
         "caching. Please unset the XLA_PERSISTENT_CACHE_PATH "
         "variable to disable persistent caching.";
  xla::XlaComputation computation((*hlo_modules)[0]->ToProto());

  std::vector<std::string> devices = {UseVirtualDevice() ? spmd_device_str
                                                         : GetDefaultDevice()};
  return std::make_shared<PjRtComputation>(std::move(computation), devices,
                                           std::move(executable));
}

torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
  // TODO(jonbolin): Incorporate CompileOptions into the hash. These are
  // deterministically generated at the moment, so they don't need to be
4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -54,6 +54,10 @@ class PjRtComputationClient : public ComputationClient {
  std::vector<ComputationPtr> Compile(
      std::vector<CompileInstance> instances) override;

  std::string SerializeComputation(const ComputationPtr computation) override;

  ComputationPtr DeserializeComputation(const std::string& serialized) override;

  std::vector<DataPtr> ExecuteComputation(
      const Computation& computation, absl::Span<const DataPtr> arguments,
      const std::string& device,
34 changes: 31 additions & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -76,6 +76,36 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) {
  return ir_value->op() != xla_not_supported;
}

XLAGraphExecutor::ComputationCache* CreateComputationCache() {
  static const size_t kMaxCacheSize =
      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
  static const bool readonlyPersistentCache =
      runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_READ_ONLY", false);
  static std::string persistentCacheDir =
      runtime::sys_util::GetEnvString("XLA_PERSISTENT_CACHE_PATH", "");
  if (!persistentCacheDir.empty()) {
    auto serialize_fn =
        [](XLAGraphExecutor::ComputationCache::TypePtr computation)
        -> std::string {
      return runtime::GetComputationClient()->SerializeComputation(
          computation->computation);
    };
    auto deserialize_fn = [](std::string serialization)
        -> XLAGraphExecutor::ComputationCache::TypePtr {
      runtime::ComputationClient::ComputationPtr computation =
          runtime::GetComputationClient()->DeserializeComputation(
              serialization);
      if (!computation) return nullptr;
      return std::make_shared<XLAGraphExecutor::CachedComputation>(
          computation, /*is_sharded=*/UseVirtualDevice());
    };
    return new XLAGraphExecutor::PersistentCache(
        kMaxCacheSize, persistentCacheDir, readonlyPersistentCache,
        serialize_fn, deserialize_fn);
  }
  return new XLAGraphExecutor::MemoryCache(kMaxCacheSize);
}

} // namespace

auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* {
@@ -477,9 +507,7 @@ void XLAGraphExecutor::MaybeDumpGraph(std::string name,
}

XLAGraphExecutor::ComputationCache* XLAGraphExecutor::GetComputationCache() {
-  static const size_t kMaxCacheSize =
-      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
-  static ComputationCache* cache = new ComputationCache(kMaxCacheSize);
+  static ComputationCache* cache = CreateComputationCache();
  return cache;
}

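For reference, CreateComputationCache above consults three environment variables: XLA_COMPILATION_CACHE_SIZE (entry limit, default 1024), XLA_PERSISTENT_CACHE_PATH (enables the persistent cache when non-empty), and XLA_PERSISTENT_CACHE_READ_ONLY (presumably disables writes to the on-disk cache). The choice is made once, at the first GetComputationCache call, so multi-process jobs need the path set before the first compilation in each process. The sketch below mirrors the per-rank directory layout used by _mp_test in the new test file; it is illustrative only, and the base directory and function names are assumptions rather than anything the commit prescribes.

# Illustrative multiprocessing sketch (not part of this commit). Mirrors the
# per-rank cache layout from _mp_test; BASE_CACHE_DIR is a hypothetical path.
import os

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

BASE_CACHE_DIR = '/tmp/xla_persistent_cache'  # hypothetical base directory


def _mp_fn(index):
  # Give every process its own subdirectory, as _mp_test does, so one process
  # does not load a result another process just wrote. Set before the first
  # compilation so CreateComputationCache sees it.
  os.environ['XLA_PERSISTENT_CACHE_PATH'] = os.path.join(
      BASE_CACHE_DIR, str(index))
  t = torch.randn(16, device=xm.xla_device())
  s = t + t
  xm.mark_step()  # compiled result is persisted under this rank's directory


if __name__ == '__main__':
  xmp.spawn(_mp_fn)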
6 changes: 6 additions & 0 deletions torch_xla/csrc/xla_graph_executor.h
@@ -163,8 +163,14 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
  };

  using ComputationCache =
      runtime::util::AbstractCache<torch::lazy::hash_t, CachedComputation,
                                   torch::lazy::HashReducer>;
  using MemoryCache =
      runtime::util::Cache<torch::lazy::hash_t, CachedComputation,
                           torch::lazy::HashReducer>;
  using PersistentCache =
      runtime::util::PersistentCache<torch::lazy::hash_t, CachedComputation,
                                     torch::lazy::HashReducer>;

  ComputationCache* GetComputationCache();

