Commit
Add documentation and python API for persistent cache (#6046)
jonb377 authored Dec 7, 2023
1 parent 39fcf8b commit a01de39
Showing 6 changed files with 78 additions and 18 deletions.
24 changes: 24 additions & 0 deletions API_GUIDE.md
@@ -314,6 +314,30 @@ tensors are always loaded back to the device they were saved from, and if
that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch,
is under active development and this behavior may change in the future.

## Compilation Caching

The XLA compiler converts the traced HLO into an executable that runs on
the devices. Compilation can be time-consuming, and in cases where the HLO
doesn't change across executions, the compilation result can be persisted to
disk for reuse, significantly reducing development iteration time.

Note that if the HLO changes between executions, a recompilation will still
occur.

This is currently an experimental opt-in API, which must be activated before
any computations are executed. Initialization is done through the
`initialize_cache` API:

```python
import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)
```

This will initialize a persistent compilation cache at the specified path. The
`readonly` parameter controls whether this worker is able to write to the
cache, which can be useful when a shared cache mount is used for an SPMD
workload.
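
For example, when several workers share one cache mount, write access might be
limited to a single worker. The sketch below is illustrative and not part of
this change: the shared path is a placeholder, and `xr.process_index()` is
assumed to identify the writing worker.

```python
import torch_xla.runtime as xr

# Hypothetical shared mount visible to all workers: process 0 populates the
# cache, every other worker opens it read-only.
xr.initialize_cache('/mnt/shared/xla_cache', readonly=xr.process_index() != 0)
```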

## Further Reading

Additional documentation is available at the
38 changes: 22 additions & 16 deletions test/test_persistent_cache.py
@@ -43,32 +43,34 @@ def _assert_correctness_and_metrics(t, xt, metrics):
        f'Unexpected value for counter {counter}: expected {value}, got {actual}'


def _mp_test(rank, metrics):
def _mp_test(rank, tmpdir, metrics):
  # In MP, the cache dir must be different for each process to avoid a race
  # condition where one process loads the compilation result of another, which
  # would break the metrics assertion.
  os.environ['XLA_PERSISTENT_CACHE_PATH'] = \
      os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank))
  xr.initialize_cache(os.path.join(tmpdir, str(rank)))

  t = torch.randn(16)
  xt = t.to(xm.xla_device())
  _assert_correctness_and_metrics(t, xt, metrics)


def _single_device_test(metrics):
def _single_device_test(tmpdir, metrics):
  xr.initialize_cache(tmpdir)
  t = torch.randn(16)
  xt = t.to(xm.xla_device())
  _assert_correctness_and_metrics(t, xt, metrics)


def _spmd_replicated_test(metrics):
def _spmd_replicated_test(tmpdir, metrics):
  xr.initialize_cache(tmpdir)
  xr.use_spmd()
  t = torch.randn(16)
  xt = t.to(xm.xla_device())
  _assert_correctness_and_metrics(t, xt, metrics)


def _spmd_sharded_test(metrics):
def _spmd_sharded_test(tmpdir, metrics):
  xr.initialize_cache(tmpdir)
  xr.use_spmd()
  t = torch.randn(16)

@@ -90,19 +92,23 @@ class PersistentCacheTest(parameterized.TestCase):

  @run_with_tmpdir
  def _run_test(self, launch_method, test_fn, tmpdir):
    os.environ['XLA_PERSISTENT_CACHE_PATH'] = tmpdir

    # Run once to warm the cache
    launch_method(test_fn, ({
        'PersistentCacheMiss': 1,
        'PersistentCacheHit': None
    },))
    launch_method(test_fn, (
        tmpdir,
        {
            'PersistentCacheMiss': 1,
            'PersistentCacheHit': None
        },
    ))

    # The second run should hit the cache
    launch_method(test_fn, ({
        'PersistentCacheMiss': None,
        'PersistentCacheHit': 1
    },))
    launch_method(test_fn, (
        tmpdir,
        {
            'PersistentCacheMiss': None,
            'PersistentCacheHit': 1
        },
    ))

  def test_persistent_cache_mp(self):
    self._run_test(xmp.spawn, _mp_test)
3 changes: 3 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -977,6 +977,9 @@ void InitXlaModuleBindings(py::module m) {
m.def("_xla_runtime_is_initialized", []() {
return runtime::GetComputationClientIfInitialized() != nullptr;
});
m.def("_xla_computation_cache_is_initialized", []() {
return XLAGraphExecutor::Get()->IsComputationCacheInitialized();
});
m.def("_get_git_revs", []() { return GetRevisions(); });
m.def("_get_xla_tensor_dimension_size",
[](const at::Tensor& tensor, int dim) {
10 changes: 8 additions & 2 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -506,9 +506,15 @@ void XLAGraphExecutor::MaybeDumpGraph(std::string name,
  }
}

bool XLAGraphExecutor::IsComputationCacheInitialized() {
  return computation_cache_ != nullptr;
}

XLAGraphExecutor::ComputationCache* XLAGraphExecutor::GetComputationCache() {
  static ComputationCache* cache = CreateComputationCache();
  return cache;
  if (computation_cache_ == nullptr) {
    computation_cache_ = CreateComputationCache();
  }
  return computation_cache_;
}

void XLAGraphExecutor::ClearPendingIrs(
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.h
@@ -173,6 +173,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
                           torch::lazy::HashReducer>;

  ComputationCache* GetComputationCache();
  bool IsComputationCacheInitialized();

  std::vector<torch::lazy::BackendDataPtr> ExecuteComputationWithBarrier(
      torch::lazy::hash_t hash, const std::vector<at::IValue>& graph_inputs,
@@ -344,6 +345,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
  std::shared_ptr<Async> SyncTensorsGraphInternal(
      std::vector<XLATensorPtr>* tensors, absl::Span<const std::string> devices,
      const SyncTensorsConfig& config, bool warm_up_cache_only = false);

  ComputationCache* computation_cache_;
};

} // namespace torch_xla
18 changes: 18 additions & 0 deletions torch_xla/runtime.py
@@ -265,3 +265,21 @@ def get_master_ip() -> str:
  if device_type() == 'TPU':
    return tpu.discover_master_worker_ip()
  raise RuntimeError(f'IP discovery not supported for device: {device_type()}')


@requires_pjrt
def initialize_cache(path: str, readonly: bool = False):
"""Initializes the persistent compilation cache. This API must be called
before any computations have been performed.
Args:
path: The path at which to store the persistent cache.
readonly: Whether or not this worker should have write access to the cache.
"""
assert not torch_xla._XLAC._xla_computation_cache_is_initialized(
), "Computation cache has already been initialized"

# TODO(jonbolin): Consider moving away from environment variables to control
# the cache.
os.environ['XLA_PERSISTENT_CACHE_PATH'] = path
os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '1' if readonly else '0'
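
A minimal end-to-end sketch of the new API (illustrative only, mirroring the
test above; the cache path is a placeholder): `initialize_cache` must run
before the first device computation so the compiled executable is persisted
and reused the next time the script runs.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.initialize_cache('/tmp/xla_cache', readonly=False)  # before any XLA computation

t = torch.randn(16).to(xm.xla_device())  # first run compiles and writes to the cache
print((t + t).cpu())  # re-running the same script hits the persistent cache
```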
