diff --git a/mlff/CLI/run_fine_tuning.py b/mlff/CLI/run_fine_tuning.py
new file mode 100644
index 0000000..3aeb698
--- /dev/null
+++ b/mlff/CLI/run_fine_tuning.py
@@ -0,0 +1,41 @@
+import argparse
+import json
+from mlff.config import from_config
+from ml_collections import config_dict
+import pathlib
+import yaml
+
+
+def fine_tune_so3krates_sparse():
+    # Create the parser
+    parser = argparse.ArgumentParser(description='Fine tune a SO3kratesSparse model.')
+    parser.add_argument(
+        '--config',
+        type=str,
+        required=True,
+        help='Path to the config file.'
+    )
+    parser.add_argument(
+        '--start_from_workdir',
+        type=str,
+        required=True,
+        help='Path to the workdir from which fine tuning should be started.'
+    )
+
+    args = parser.parse_args()
+
+    config = pathlib.Path(args.config).expanduser().absolute().resolve()
+    if config.suffix == '.json':
+        with open(config, mode='r') as fp:
+            cfg = config_dict.ConfigDict(json.load(fp=fp))
+    elif config.suffix == '.yaml':
+        with open(config, mode='r') as fp:
+            cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))
+    else:
+        raise ValueError(f'Unknown config file format {config.suffix}. Expected .json or .yaml.')
+
+    from_config.run_fine_tuning(cfg, start_from_workdir=args.start_from_workdir)
+
+
+if __name__ == '__main__':
+    fine_tune_so3krates_sparse()
diff --git a/mlff/config/from_config.py b/mlff/config/from_config.py
index 16d34dd..c86a60a 100644
--- a/mlff/config/from_config.py
+++ b/mlff/config/from_config.py
@@ -366,3 +366,245 @@ def run_evaluation(
         batch_max_num_graphs=config.training.batch_max_num_graphs,
         write_batch_metrics_to=write_batch_metrics_to
     )
+
+
+def run_fine_tuning(
+        config: config_dict.ConfigDict,
+        start_from_workdir: str
+):
+    """Run fine tuning given a config.
+
+    Args:
+        config (config_dict.ConfigDict): The config.
+        start_from_workdir (str): The workdir of the model that should be fine tuned.
+
+    Returns:
+
+    """
+    workdir = Path(config.workdir).expanduser().resolve()
+    if workdir.exists():
+        raise ValueError(
+            f'Please specify a new workdir for fine tuning. Workdir {workdir} already exists.'
+        )
+    workdir.mkdir(exist_ok=False)
+
+    start_from_workdir = Path(start_from_workdir).expanduser().resolve()
+    if not start_from_workdir.exists():
+        raise ValueError(
+            f'Trying to start fine tuning from {start_from_workdir} but directory does not exist.'
+        )
+    params = load_params_from_workdir(start_from_workdir)
+
+    data_filepath = config.data.filepath
+    data_filepath = Path(data_filepath).expanduser().resolve()
+
+    if data_filepath.suffix == '.npz':
+        loader = data.NpzDataLoaderSparse(input_file=data_filepath)
+    elif data_filepath.stem[:5].lower() == 'spice':
+        logging.mlff(f'Found SPICE dataset at {data_filepath}.')
+        if data_filepath.suffix != '.hdf5':
+            raise ValueError(
+                f'Loader assumes that SPICE is in hdf5 format. Found {data_filepath.suffix} as '
+                f'suffix.')
+        loader = data.SpiceDataLoaderSparse(input_file=data_filepath)
+    else:
+        loader = data.AseDataLoaderSparse(input_file=data_filepath)
+
+    # Get the units of the data.
+    energy_unit = eval(config.data.energy_unit)
+    length_unit = eval(config.data.length_unit)
+
+    # Get the total number of data points.
+    num_data = loader.cardinality()
+    num_train = config.training.num_train
+    num_valid = config.training.num_valid
+
+    if num_train + num_valid > num_data:
+        raise ValueError(f"num_train + num_valid = {num_train + num_valid} exceeds the number of data points {num_data}"
+                         f" in {data_filepath}.")
+
+    split_seed = config.data.split_seed
+    numpy_rng = np.random.RandomState(split_seed)
+
+    # Choose the data points that are used for training (training + validation data).
+    all_indices = np.arange(num_data)
+    numpy_rng.shuffle(all_indices)
+    # We sort the indices after extracting them from the shuffled list, since we iteratively load the data with the
+    # data loader.
+    training_and_validation_indices = np.sort(all_indices[:(num_train + num_valid)])
+    test_indices = np.sort(all_indices[(num_train + num_valid):])
+
+    # Cutoff is in Angstrom, so we have to divide the cutoff by the length unit.
+    training_and_validation_data, data_stats = loader.load(
+        cutoff=config.model.cutoff / length_unit,
+        pick_idx=training_and_validation_indices
+    )
+    # Since the training and validation indices are sorted, the index i at the n-th entry in
+    # training_and_validation_indices corresponds to the n-th entry in training_and_validation_data, which is the i-th
+    # data entry in the loaded data.
+    split_indices = np.arange(num_train + num_valid)
+    numpy_rng.shuffle(split_indices)
+    internal_train_indices = split_indices[:num_train]
+    internal_validation_indices = split_indices[num_train:]
+
+    training_data = [training_and_validation_data[i_train] for i_train in internal_train_indices]
+    validation_data = [training_and_validation_data[i_val] for i_val in internal_validation_indices]
+    del training_and_validation_data
+
+    assert len(internal_train_indices) == num_train
+    assert len(internal_validation_indices) == num_valid
+
+    if config.data.shift_mode == 'mean':
+        config.data.energy_shifts = config_dict.placeholder(dict)
+        energy_mean = data.transformations.calculate_energy_mean(training_data) * energy_unit
+        num_nodes = data.transformations.calculate_average_number_of_nodes(training_data)
+        energy_shifts = {str(a): float(energy_mean / num_nodes) for a in range(119)}
+        config.data.energy_shifts = energy_shifts
+    elif config.data.shift_mode == 'custom':
+        if config.data.energy_shifts is None:
+            raise ValueError('For config.data.shift_mode == custom config.data.energy_shifts must be given.')
+    else:
+        config.data.energy_shifts = {str(a): 0. for a in range(119)}
+
+    # If messages are normalized by the average number of neighbors, we need to calculate this quantity from the
+    # training data.
+    if config.model.message_normalization == 'avg_num_neighbors':
+        config.data.avg_num_neighbors = config_dict.placeholder(float)
+        avg_num_neighbors = data.transformations.calculate_average_number_of_neighbors(training_data)
+        config.data.avg_num_neighbors = np.array(avg_num_neighbors).item()
+
+    training_data = list(data.transformations.subtract_atomic_energy_shifts(
+        data.transformations.unit_conversion(
+            training_data,
+            energy_unit=energy_unit,
+            length_unit=length_unit
+        ),
+        atomic_energy_shifts={int(k): v for (k, v) in config.data.energy_shifts.items()}
+    ))
+
+    validation_data = list(data.transformations.subtract_atomic_energy_shifts(
+        data.transformations.unit_conversion(
+            validation_data,
+            energy_unit=energy_unit,
+            length_unit=length_unit
+        ),
+        atomic_energy_shifts={int(k): v for (k, v) in config.data.energy_shifts.items()}
+    ))
+
+    opt = make_optimizer_from_config(config)
+
+    # TODO: One could load the model from the original workdir itself, but that would require either a dedicated
+    #  fine_tuning_config or silently ignoring the model section of the config file. For now, the user is responsible
+    #  for defining a model config that matches the loaded parameters; if they do not match, the code fails
+    #  immediately, so the mismatch is directly visible to the user.
+    so3k = make_so3krates_sparse_from_config(config)
+
+    loss_fn = training_utils.make_loss_fn(
+        get_energy_and_force_fn_sparse(so3k),
+        weights=config.training.loss_weights
+    )
+
+    if config.training.batch_max_num_nodes is None:
+        assert config.training.batch_max_num_edges is None
+
+        batch_max_num_nodes = data_stats['max_num_of_nodes'] * (config.training.batch_max_num_graphs - 1) + 1
+        batch_max_num_edges = data_stats['max_num_of_edges'] * (config.training.batch_max_num_graphs - 1) + 1
+
+        config.training.batch_max_num_nodes = batch_max_num_nodes
+        config.training.batch_max_num_edges = batch_max_num_edges
+
+    # internal_*_indices only run from [0, num_train+num_valid]. To get their original position in the full data set
+    # we collect them from training_and_validation_indices. Since we will load training and validation data as
+    # training_and_validation_data[internal_*_indices], we need to make sure that training_and_validation_indices
+    # and training_and_validation_data have the same order in the sense of referencing indices. This is achieved by
+    # sorting the indices as described above.
+    train_indices = training_and_validation_indices[internal_train_indices]
+    validation_indices = training_and_validation_indices[internal_validation_indices]
+    assert len(train_indices) == num_train
+    assert len(validation_indices) == num_valid
+    with open(workdir / 'data_splits.json', 'w') as fp:
+        j = dict(
+            training=train_indices.tolist(),
+            validation=validation_indices.tolist(),
+            test=test_indices.tolist()
+        )
+        json.dump(j, fp)
+
+    with open(workdir / 'hyperparameters.json', 'w') as fp:
+        # json_config = config.to_dict()
+        # energy_shifts = json_config['data']['energy_shifts']
+        # energy_shifts = jax.tree_map(lambda x: x.item(), energy_shifts)
+        json.dump(config.to_dict(), fp)
+
+    with open(workdir / "hyperparameters.yaml", "w") as yaml_file:
+        yaml.dump(config.to_dict(), yaml_file, default_flow_style=False)
+
+    wandb.init(config=config.to_dict(), **config.training.wandb_init_args)
+
+    logging.mlff(
+        f'Fine tuning model from {start_from_workdir} on {data_filepath}!'
+    )
+    training_utils.fit(
+        model=so3k,
+        optimizer=opt,
+        loss_fn=loss_fn,
+        graph_to_batch_fn=jraph_utils.graph_to_batch_fn,
+        batch_max_num_edges=config.training.batch_max_num_edges,
+        batch_max_num_nodes=config.training.batch_max_num_nodes,
+        batch_max_num_graphs=config.training.batch_max_num_graphs,
+        training_data=training_data,
+        validation_data=validation_data,
+        params=params,
+        ckpt_dir=workdir / 'checkpoints',
+        eval_every_num_steps=config.training.eval_every_num_steps,
+        allow_restart=config.training.allow_restart,
+        num_epochs=config.training.num_epochs,
+        training_seed=config.training.training_seed,
+        model_seed=config.training.model_seed,
+        log_gradient_values=config.training.log_gradient_values
+    )
+    logging.mlff('Fine tuning has finished!')
+
+
+def load_params_from_workdir(workdir):
+    """Load parameters from workdir.
+
+    Args:
+        workdir (str): Path to `workdir`.
+
+    Returns:
+        PyTree of parameters.
+
+    Raises:
+        ValueError: Workdir does not have a checkpoint directory.
+        RuntimeError: Loaded parameters are None.
+
+    """
+    load_path = Path(workdir).expanduser().resolve() / "checkpoints"
+    if not load_path.exists():
+        raise ValueError(
+            f'Trying to load parameters from {load_path} but path does not exist.'
+        )
+
+    loaded_mngr = checkpoint.CheckpointManager(
+        load_path,
+        {
+            "params": checkpoint.PyTreeCheckpointer(),
+        },
+        options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"),
+    )
+    mgr_state = loaded_mngr.restore(
+        loaded_mngr.latest_step(),
+        {
+            "params": checkpoint.PyTreeCheckpointer(),
+        })
+    params = mgr_state.get("params")
+
+    if params is None:
+        raise RuntimeError(
+            f'Parameters loaded from {load_path} are None.'
+        )
+
+    del loaded_mngr
+
+    return params
diff --git a/mlff/utils/training_utils.py b/mlff/utils/training_utils.py
index 0944c09..7ffcc7d 100644
--- a/mlff/utils/training_utils.py
+++ b/mlff/utils/training_utils.py
@@ -127,11 +127,12 @@ def fit(
         batch_max_num_nodes,
         batch_max_num_edges,
         batch_max_num_graphs,
+        params=None,
         num_epochs: int = 100,
         ckpt_dir: str = None,
         ckpt_manager_options: dict = None,
         eval_every_num_steps: int = 1000,
-        allow_restart=False,
+        allow_restart: bool = False,
         training_seed: int = 0,
         model_seed: int = 0,
         use_wandb: bool = True,
@@ -150,6 +151,8 @@ def fit(
         batch_max_num_nodes (int): Maximal number of nodes per batch.
         batch_max_num_edges (int): Maximal number of edges per batch.
         batch_max_num_graphs (int): Maximal number of graphs per batch.
+        params: Parameters to start from during training. If not given, parameters are either initialized randomly
+            or loaded from ckpt_dir if a checkpoint already exists and `allow_restart=True`.
         num_epochs (int): Number of training epochs.
         ckpt_dir (str): Checkpoint path.
         ckpt_manager_options (dict): Checkpoint manager options.
@@ -193,7 +196,6 @@
     processed_nodes = 0
     step = 0
 
-    params = None
     opt_state = None
     for epoch in range(num_epochs):
        # Shuffle the training data.
@@ -215,8 +217,8 @@
             # Training data is numpy arrays so we now transform them to jax.numpy arrays.
             batch_training = jax.tree_map(jnp.array, batch_training)
 
-            # In the first step, initialize the parameters or load from existing checkpoint.
-            if step == 0:
+            # If params is None (i.e. in the first step), initialize the parameters or load from existing checkpoint.
+            if params is None:
                 # Check if checkpoint already exists.
                 latest_step = ckpt_mngr.latest_step()
                 if latest_step is not None:
@@ -233,6 +235,8 @@
                 else:
                     params = model.init(jax_rng, batch_training)
 
+            # If the optimizer state is None (i.e. in the first step), initialize it from the parameter PyTree.
+            if opt_state is None:
                 opt_state = optimizer.init(params)
 
             # Make sure parameters and opt_state are set.
diff --git a/setup.py b/setup.py
index 1b9f151..d30bf15 100644
--- a/setup.py
+++ b/setup.py
@@ -42,6 +42,7 @@
             "trajectory_to_xyz=mlff.cAPI.mlff_postprocessing:trajectory_to_xyz",
             "to_mlff_input=mlff.cAPI.mlff_input_processing:to_mlff_input",
             "train_so3krates_sparse=mlff.CLI.run_training:train_so3krates_sparse",
+            "fine_tune_so3krates_sparse=mlff.CLI.run_fine_tuning:fine_tune_so3krates_sparse",
             "evaluate_so3krates_sparse=mlff.CLI.run_evaluation:evaluate_so3krates_sparse",
             "evaluate_so3krates_sparse_on=mlff.CLI.run_evaluation_on:evaluate_so3krates_sparse_on"
         ],
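
A minimal usage sketch of the new entry point, assuming a YAML config named fine_tune.yaml (whose config.workdir points to a directory that does not exist yet) and a pretrained model in ./pretrained_workdir containing a checkpoints subdirectory; both names are placeholders. The console script registered in setup.py would be invoked as fine_tune_so3krates_sparse --config fine_tune.yaml --start_from_workdir ./pretrained_workdir, and the equivalent direct Python call, mirroring mlff/CLI/run_fine_tuning.py, is roughly:

    # Sketch only; 'fine_tune.yaml' and './pretrained_workdir' are placeholder paths.
    import yaml
    from ml_collections import config_dict
    from mlff.config import from_config

    # Load the fine-tuning config and wrap it in a ConfigDict, as the CLI does.
    with open('fine_tune.yaml', mode='r') as fp:
        cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))

    # Start fine tuning from the parameters stored in the pretrained workdir.
    from_config.run_fine_tuning(cfg, start_from_workdir='./pretrained_workdir')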