Enable persistent compilation caching #5804
Conversation
Ah bummer, it's not running CI since I'm targeting a different branch... I was hoping to see results for GPU. I guess we'll need to wait until the others land so I can target master.
```python
@unittest.skipUnless(xr.device_type() in {'TPU', 'GPU'},
                     'Device type does not support persistent caching')
```
why not cpu?
I tested it, but CPU isn't supported (deserialization fails). JAX has a similar restriction: https://github.com/google/jax/blob/234be736c4cdd8da4197078278d35a6a1cde3767/tests/compilation_cache_test.py#L69C41-L69C41
```cpp
using ComputationCache =
    runtime::util::AbstractCache<torch::lazy::hash_t, CachedComputation,
                                 torch::lazy::HashReducer>;
using MemoryCache =
    runtime::util::Cache<torch::lazy::hash_t, CachedComputation,
```
I am confused about the difference between `ComputationCache` and `MemoryCache`.
`ComputationCache` is used for the umbrella return type, and `MemoryCache` and `PersistentCache` are subtypes of `ComputationCache`.
The types were split out in #5800
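For anyone skimming this thread, here's a minimal Python sketch of that relationship (illustrative only; the real types are C++ aliases over the `runtime::util` templates quoted above):

```python
# Illustrative Python analogue of the C++ aliases above; not the actual
# implementation. ComputationCache is the umbrella interface, and the two
# concrete caches are interchangeable behind it.
import os
from abc import ABC, abstractmethod


class ComputationCache(ABC):
  """Umbrella type, analogous to runtime::util::AbstractCache."""

  @abstractmethod
  def get(self, hash_key):
    ...

  @abstractmethod
  def add(self, hash_key, computation):
    ...


class MemoryCache(ComputationCache):
  """In-process cache, analogous to runtime::util::Cache."""

  def __init__(self):
    self._entries = {}

  def get(self, hash_key):
    return self._entries.get(hash_key)

  def add(self, hash_key, computation):
    self._entries[hash_key] = computation


class PersistentCache(ComputationCache):
  """Disk-backed cache: serialized entries survive process restarts."""

  def __init__(self, path):
    self._path = path
    os.makedirs(path, exist_ok=True)

  def get(self, hash_key):
    try:
      with open(os.path.join(self._path, str(hash_key)), 'rb') as f:
        return f.read()  # real code would deserialize the executable here
    except FileNotFoundError:
      return None

  def add(self, hash_key, computation):
    with open(os.path.join(self._path, str(hash_key)), 'wb') as f:
      f.write(computation)  # real code would serialize the executable here
```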
@jonb377 do you still want to merge this one?
LGTM
LGTM. Thanks!
```python
def _single_device_test(metrics):
  t = torch.randn(16)
```
Small suggestion: you can move these test cases into your test class by making them `@staticmethod`s.
I tried that initially, but for the ProcessPoolExecutor it complained that they weren't pickleable 🥲
Hmm, are you sure you didn't use `@classmethod`? We have other tests that use `@staticmethod` to create pickleable test cases.
Just retried the single device test:
```python
class PersistentCacheTest(parameterized.TestCase):
  """
  Test suite to verify compilation cache across processes. Tests will run
  multiple Python subprocesses which use the XLA runtime to populate the cache
  and perform assertions on the metrics generated.
  """

  @staticmethod
  def _single_device_test(metrics):
    t = torch.randn(16)
    xt = t.to(xm.xla_device())
    _assert_correctness_and_metrics(t, xt, metrics)
```
It hit this:
```
Traceback (most recent call last):
  File "/home/ptxla/.local/lib/python3.8/site-packages/absl/testing/parameterized.py", line 321, in bound_param_test
    return test_method(self, *testcase_params)
  File "test_persistent_cache.py", line 23, in run
    f(*args, **kwargs)
  File "test_persistent_cache.py", line 103, in test_persistent_cache
    launch_method(test_fn, ({
  File "test_persistent_cache.py", line 31, in _test_spawn
    pool.submit(fn, *args).result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/local/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: cannot pickle 'staticmethod' object
```
The _mp_test also hit this when using staticmethod... I wonder what's going wrong lol
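For reference, a standalone repro of that error (hypothetical, not from this PR): on the Python 3.8 used in CI, a name decorated with `@staticmethod` still refers to the raw descriptor inside the class body, and the descriptor itself can't be pickled, even though the same function accessed through the class can be.

```python
# Hypothetical repro: capturing a staticmethod at class-definition time (as
# parameterized test params do) hands pickle the descriptor, not the function.
import pickle


class Demo:

  @staticmethod
  def work(x):
    return x + 1

  # Inside the class body, 'work' is still the raw staticmethod descriptor.
  CASES = {'single_device': work}


pickle.dumps(Demo.work)  # fine: attribute access unwraps the descriptor

try:
  pickle.dumps(Demo.CASES['single_device'])
except TypeError as e:
  print(e)  # cannot pickle 'staticmethod' object
```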
MP GPU looks good, but single-device seems to be broken by the same issue addressed in #6023 (cc @vanbasten23). I've disabled single-device and SPMD GPU tests.
This change enables persistent caching by combining #5800 and #5803. It uses the serialization functionality in PjRtLoadedExecutable to convert the executables to/from strings, which are written to disk by the persistent cache.
The persistent cache is enabled by setting the environment variable `XLA_PERSISTENT_CACHE_PATH` to the desired compilation cache path. An additional environment variable `XLA_PERSISTENT_CACHE_READ_ONLY` can be used to control whether the cache is read-only, which can be useful when the cache is shared across workers in an SPMD setting. Note that the persistent cache does not perform any eviction, so it is currently up to the user to clean the cache.
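For reference, a minimal usage sketch (the variable names come from this change; the truthy value `'1'` and the exact initialization timing are my assumptions):

```python
# Minimal usage sketch of the persistent cache. The env var names are from
# this PR; '1' as the read-only value is an assumption.
import os

# Set before the XLA runtime initializes so the cache path is picked up.
os.environ['XLA_PERSISTENT_CACHE_PATH'] = '/tmp/xla_compile_cache'
# Optional: mark the cache read-only, e.g. for non-leader SPMD workers
# sharing a cache directory.
# os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '1'

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(16).to(xm.xla_device())
print(t + t)  # first run compiles and persists; later runs load from disk
```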