From 39abb58f3ff0f83f6225199a26a0c5379dc2cf90 Mon Sep 17 00:00:00 2001 From: Thorben Frank Date: Thu, 22 Aug 2024 14:30:33 +0200 Subject: [PATCH] update params loader --- mlff/io/checkpoint.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/mlff/io/checkpoint.py b/mlff/io/checkpoint.py index f3e98d6..1939b70 100644 --- a/mlff/io/checkpoint.py +++ b/mlff/io/checkpoint.py @@ -3,10 +3,29 @@ import os from orbax.checkpoint import PyTreeCheckpointer, Checkpointer, PyTreeCheckpointHandler +from orbax import checkpoint +import pathlib __STEP_PREFIX__: str = 'ckpt' +def load_params_from_ckpt_dir(ckpt_dir): + loaded_mngr = checkpoint.CheckpointManager( + pathlib.Path(ckpt_dir).resolve(), + item_names=('state',), + item_handlers={'state': checkpoint.StandardCheckpointHandler()}, + options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"), + ) + + mngr_state = loaded_mngr.restore( + loaded_mngr.latest_step() + ) + + state = mngr_state.get('state') + + return state['valid_params'] + + def load_state_from_ckpt_dir(ckpt_dir: str): # mngr = CheckpointManager(ckpt_dir, __CHECKPOINTERS__, options=CheckpointManagerOptions(step_prefix=__STEP_PREFIX__)) # return mngr.restore(n)['state'] @@ -27,5 +46,5 @@ def load_state_from_ckpt_dir(ckpt_dir: str): return ckptr.restore(abs_ckpt_dir / f'{__STEP_PREFIX__}_{max_step}/state', item=None) -def load_params_from_ckpt_dir(ckpt_dir: str): +def _load_params_from_ckpt_dir(ckpt_dir: str): return load_state_from_ckpt_dir(ckpt_dir)['valid_params']