
Enable persistent compilation caching #5804

Merged 5 commits into master on Dec 7, 2023

Conversation
@jonb377 (Collaborator) commented Nov 15, 2023

This change enables persistent caching by combining #5800 and #5803. It uses the serialization functionality in PjRtLoadedExecutable to convert the executables to/from strings, which are written to disk by the persistent cache.

The persistent cache is enabled by setting the environment variable XLA_PERSISTENT_CACHE_PATH to the desired compilation cache path. An additional environment variable, XLA_PERSISTENT_CACHE_READ_ONLY, can be used to control whether the cache is read-only, which is useful when the cache is shared across workers in an SPMD setting.

Note that the persistent cache does not perform any eviction, so it is currently up to the user to clean the cache.
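As a minimal sketch, enabling the cache from Python might look like this (the cache directory is illustrative, and the variables must be set before the XLA runtime initializes):

```python
import os

# Illustrative cache directory; any writable path works.
os.environ['XLA_PERSISTENT_CACHE_PATH'] = '/tmp/xla_compile_cache'

# Optional: mark this worker read-only so that a single writer populates
# a shared cache in an SPMD setting.
os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '1'

# import torch_xla  # must happen after the variables above are set
```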

@jonb377 jonb377 force-pushed the jonbolin/use-persistent-cache branch from 0307d7a to 01fabcd Compare November 15, 2023 06:21
@jonb377 (Collaborator, Author) commented Nov 15, 2023

Ah bummer, it's not running CI since I'm targeting a different branch... I was hoping to see results for GPU. I guess we'll need to wait until the others land so I can target master.

@jonb377 jonb377 mentioned this pull request Nov 15, 2023
@jonb377 jonb377 self-assigned this Nov 15, 2023


@unittest.skipUnless(xr.device_type() in {'TPU', 'GPU'},
'Device type does not support persistent caching')
Collaborator

why not cpu?

Collaborator Author

I tested it, but CPU isn't supported (deserialization fails). JAX has a similar restriction: https://github.com/google/jax/blob/234be736c4cdd8da4197078278d35a6a1cde3767/tests/compilation_cache_test.py#L69C41-L69C41

Comment on lines 165 to 169

using ComputationCache =
    runtime::util::AbstractCache<torch::lazy::hash_t, CachedComputation,
                                 torch::lazy::HashReducer>;
using MemoryCache =
    runtime::util::Cache<torch::lazy::hash_t, CachedComputation,
                         torch::lazy::HashReducer>;
Collaborator

I'm confused about the difference between ComputationCache and MemoryCache.

Collaborator Author

ComputationCache is the umbrella return type; MemoryCache and PersistentCache are subtypes of it.
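An illustrative Python sketch of the relationship described above (the real types are C++ templates split out in #5800; these class names and methods are only stand-ins for the actual interface):

```python
from abc import ABC, abstractmethod

class AbstractCache(ABC):
    """Plays ComputationCache's role: the umbrella interface callers see."""

    @abstractmethod
    def get(self, key): ...

    @abstractmethod
    def add(self, key, value): ...

class MemoryCache(AbstractCache):
    """In-memory subtype, like the pre-existing Cache."""

    def __init__(self):
        self._entries = {}

    def get(self, key):
        return self._entries.get(key)

    def add(self, key, value):
        self._entries[key] = value

class PersistentCache(MemoryCache):
    """Subtype that would additionally serialize entries to disk (elided here)."""
```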

Collaborator Author

The types were split out in #5800

@jonb377 jonb377 force-pushed the jonbolin/use-persistent-cache branch 2 times, most recently from f124553 to e3ae16c Compare November 16, 2023 03:02
@JackCaoG (Collaborator) commented Dec 1, 2023

@jonb377 do you still want to merge this one?

@jonb377 (Collaborator, Author) commented Dec 1, 2023

> @jonb377 do you still want to merge this one?

@JackCaoG Yes, just pending reviews. I'll ping some folks.

Contributor

@yeounoh yeounoh left a comment

LGTM

@jonb377 jonb377 force-pushed the jonbolin/comp-env-hash branch from 52781bc to eeee59e Compare December 5, 2023 20:54
@jonb377 jonb377 force-pushed the jonbolin/use-persistent-cache branch from e3ae16c to 3686f29 Compare December 5, 2023 20:54
@jonb377 jonb377 force-pushed the jonbolin/comp-env-hash branch from eeee59e to d2190e4 Compare December 5, 2023 20:56
@jonb377 jonb377 force-pushed the jonbolin/use-persistent-cache branch from 3686f29 to ab037fc Compare December 5, 2023 20:56
@jonb377 jonb377 force-pushed the jonbolin/comp-env-hash branch from d2190e4 to aee73ba Compare December 6, 2023 01:37
@jonb377 jonb377 force-pushed the jonbolin/use-persistent-cache branch from ab037fc to f009baa Compare December 6, 2023 01:40
Base automatically changed from jonbolin/comp-env-hash to master December 6, 2023 16:47
@jonb377 jonb377 force-pushed the jonbolin/use-persistent-cache branch from f009baa to 6041867 Compare December 6, 2023 16:49
Collaborator

@will-cromar will-cromar left a comment

LGTM. Thanks!

Comment on lines +58 to +59
def _single_device_test(metrics):
t = torch.randn(16)
Collaborator

small suggestion: you can move these test cases into your test class by making them @staticmethods

Collaborator Author

I tried that initially, but for the ProcessPoolExecutor it complained that they weren't pickleable 🥲

Collaborator

Hmm, are you sure you didn't use @classmethod? We have other tests that use @staticmethod to create pickleable test cases

Collaborator Author

Just retried the single device test:

class PersistentCacheTest(parameterized.TestCase):
  """
  Test suite to verify compilation cache across processes. Tests will run
  multiple Python subprocesses which use the XLA runtime to populate the cache
  and perform assertions on the metrics generated.
  """

  @staticmethod
  def _single_device_test(metrics):
    t = torch.randn(16)
    xt = t.to(xm.xla_device())
    _assert_correctness_and_metrics(t, xt, metrics)

It hit this:

Traceback (most recent call last):
  File "/home/ptxla/.local/lib/python3.8/site-packages/absl/testing/parameterized.py", line 321, in bound_param_test
    return test_method(self, *testcase_params)
  File "test_persistent_cache.py", line 23, in run
    f(*args, **kwargs)
  File "test_persistent_cache.py", line 103, in test_persistent_cache
    launch_method(test_fn, ({
  File "test_persistent_cache.py", line 31, in _test_spawn
    pool.submit(fn, *args).result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/local/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: cannot pickle 'staticmethod' object

The _mp_test also hit this when using staticmethod... I wonder what's going wrong lol

@jonb377 (Collaborator, Author) commented Dec 6, 2023

MP GPU looks good, but single-device seems to be broken by the same issue addressed in #6023 (cc @vanbasten23). I've disabled single-device and SPMD GPU tests.

@jonb377 jonb377 merged commit fae0166 into master Dec 7, 2023
@jonb377 jonb377 deleted the jonbolin/use-persistent-cache branch December 7, 2023 00:11
jonb377 added a commit that referenced this pull request Dec 7, 2023
jonb377 added a commit that referenced this pull request Dec 8, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
4 participants