diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index b848f7ce..c90d0d6d 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -17,6 +17,8 @@ properties not included in any tree mapping operations. ### Added - The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`. +- `HandlerAwaitableSignal` for signalling between Checkpointing layers to enable +async directory creation. ### Fixed - Fix a bug where snapshots are not released by `wait_for_new_checkpoint` diff --git a/checkpoint/orbax/checkpoint/_src/futures/synchronization.py b/checkpoint/orbax/checkpoint/_src/futures/synchronization.py new file mode 100644 index 00000000..4647c0d8 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/futures/synchronization.py @@ -0,0 +1,69 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Synchronization utilities for futures.""" + +import enum +import itertools +from orbax.checkpoint._src.multihost import multihost + + +class HandlerAwaitableSignal(enum.Enum): + """Defines signals that may be awaited by a `CheckpointHandler`. + + Signals may be passed from `CheckpointManager` or `Checkpointer` layers to + `CheckpointHandler or below.` + + Attributes: + STEP_DIRECTORY_CREATION: When recieved, indicates that the step directory + has been created. The handler should not attempt to write files before the + directory is created. + ITEM_DIRECTORY_CREATION: When recieved, indicates that the item directory + has been created. The handler should not attempt to write files before the + directory is created. + """ + + STEP_DIRECTORY_CREATION = 'step_directory_creation' + ITEM_DIRECTORY_CREATION = 'item_directory_creation' + + +class HandlerAwaitableSignalBarrierKeyGenerator: + """A unique barrier key generator for a `HandlerAwaitableSignal`.""" + + _operation_id_counter = itertools.count() + _operation_id = None + + @classmethod + def next_operation_id(cls) -> int: + cls._operation_id = next(cls._operation_id_counter) + return cls._operation_id + + @classmethod + def get_unique_barrier_key(cls, signal: HandlerAwaitableSignal) -> str: + """Returns a unique barrier key for the signal. + + Args: + signal: The signal to generate a barrier key for. + + Raises: + ValueError: If `_operation_id` is not initialized. + """ + if cls._operation_id is None: + raise ValueError( + '_operation_id is not initialized. Please call `next_operation_id()`' + ' first.' + ) + return multihost.unique_barrier_key( + signal.value, suffix=str(cls._operation_id) + ) diff --git a/checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py b/checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py new file mode 100644 index 00000000..f9b0e9bc --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py @@ -0,0 +1,56 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from orbax.checkpoint._src.futures import synchronization +from orbax.checkpoint._src.multihost import multihost + + +HandlerAwaitableSignalBarrierKeyGenerator = ( + synchronization.HandlerAwaitableSignalBarrierKeyGenerator +) + + +class HandlerAwaitableSignalBarrierKeyGeneratorTest(absltest.TestCase): + + def test_get_unique_barrier_key(self): + step_directory_creation_signal = ( + synchronization.HandlerAwaitableSignal.STEP_DIRECTORY_CREATION + ) + expected_barrier_key_0 = multihost.unique_barrier_key( + step_directory_creation_signal.value, suffix="0" + ) + expected_barrier_key_1 = multihost.unique_barrier_key( + step_directory_creation_signal.value, suffix="1" + ) + + HandlerAwaitableSignalBarrierKeyGenerator.next_operation_id() + barrier_key_0 = ( + HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key( + step_directory_creation_signal + ) + ) + HandlerAwaitableSignalBarrierKeyGenerator.next_operation_id() + barrier_key_1 = ( + HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key( + step_directory_creation_signal + ) + ) + + self.assertEqual(barrier_key_0, expected_barrier_key_0) + self.assertEqual(barrier_key_1, expected_barrier_key_1) + + +if __name__ == "__main__": + absltest.main()