How do I load a pre-trained model? #403
-
There is a notebook that explains how to save and load models (https://github.com/google/brax/blob/main/notebooks/training.ipynb), but there testing happens right after training. My question is how I can test without running the training process, simply by loading the params.
-
Hi @eleninisioti ! This is a great feature request that we haven't gotten around to implementing: a model saver that saves the parameter config alongside the weights. Here is an example of what the loader would do (where load_config_dict is to-be-implemented, and the whole thing needs to be wrapped in a nice function):

```python
# Imports added for completeness; adjust module paths to your brax version.
from typing import Any, Dict, Optional

import jax
from brax.io import model
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as brax_networks


def make_inference_fn(
    observation_size: int,
    action_size: int,
    normalize_observations: bool = True,
    network_factory_kwargs: Optional[Dict[str, Any]] = None,
):
  """Rebuilds the PPO policy factory from the saved network config."""
  normalize = lambda x, y: x
  if normalize_observations:
    normalize = running_statistics.normalize
  ppo_network = brax_networks.make_ppo_networks(
      observation_size,
      action_size,
      preprocess_observations_fn=normalize,
      **(network_factory_kwargs or {}),
  )
  make_policy = brax_networks.make_inference_fn(ppo_network)
  return make_policy


config_dict = load_config_dict(checkpoint_path)
make_policy = make_inference_fn(
    config_dict['observation_size'],
    config_dict['action_size'],
    config_dict['normalize_observations'],
    network_factory_kwargs=config_dict['network_factory_kwargs'],  # e.g. {"policy_hidden_layer_sizes": (128,) * 4}
)
params = model.load_params(checkpoint_file)
jit_inference_fn = jax.jit(make_policy(params, deterministic=True))
```
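Two hedged sketches to round out the reply. First, load_config_dict is not part of brax; assuming the config is stored as a JSON file next to the weights, a minimal implementation could look like the following (the helper names and the file name are hypothetical):

```python
import json
import os
from typing import Any, Dict


def save_config_dict(checkpoint_path: str, config_dict: Dict[str, Any]) -> None:
  # Hypothetical helper: write the network config next to the saved weights.
  with open(os.path.join(checkpoint_path, 'network_config.json'), 'w') as f:
    json.dump(config_dict, f)


def load_config_dict(checkpoint_path: str) -> Dict[str, Any]:
  # Counterpart used by the loader snippet above.
  with open(os.path.join(checkpoint_path, 'network_config.json')) as f:
    return json.load(f)


# At training time, save both pieces so the loader has everything it needs:
# model.save_params(checkpoint_file, params)
# save_config_dict(checkpoint_path, {
#     'observation_size': env.observation_size,
#     'action_size': env.action_size,
#     'normalize_observations': True,
#     'network_factory_kwargs': {'policy_hidden_layer_sizes': [128, 128, 128, 128]},
# })
```

Note that JSON round-trips tuples as lists, so (128,) * 4 comes back as [128, 128, 128, 128], which the network factory still accepts. Second, with make_policy and params in hand, evaluating without retraining follows the same rollout loop as the training notebook (the env name and backend below are placeholders; use whatever the policy was trained on):

```python
import jax
from brax import envs

# Placeholder environment; must match the one the policy was trained on.
env = envs.get_environment(env_name='ant', backend='positional')
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

rng = jax.random.PRNGKey(0)
state = jit_env_reset(rng=rng)
total_reward = 0.0
for _ in range(1000):
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)
  total_reward += float(state.reward)
```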
-
Hi!
-
I'm a bit of a newb here... I haven't quite figured out yet how to resume training from some previously trained model params that I would load at the start of the experiment; how can I pass in pre-trained model params before running the train_fn function?
-
Please see #438. This makes checkpointing available with minimal changes to the public API.
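For reference, a hedged sketch of what resuming training looks like once that checkpointing is in place. The keyword name is an assumption based on recent brax versions, where the PPO train function exposes a restore_checkpoint_path argument; check #438 and the current API if yours differs:

```python
import functools

from brax import envs
from brax.training.agents.ppo import train as ppo

# Placeholder env name/backend.
env = envs.get_environment(env_name='ant', backend='positional')

train_fn = functools.partial(
    ppo.train,
    num_timesteps=50_000_000,
    num_evals=10,
    episode_length=1000,
    normalize_observations=True,
    # Assumption: directory of checkpoints written by the previous run.
    restore_checkpoint_path='/path/to/previous/run/checkpoints',
)
make_inference_fn, params, _ = train_fn(environment=env)
```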