first draft for finetuning logic
thorben-frank committed Jan 31, 2024
1 parent abe2037 commit f97cb6e
Showing 4 changed files with 290 additions and 4 deletions.
39 changes: 39 additions & 0 deletions mlff/CLI/run_fine_tuning.py
@@ -0,0 +1,39 @@
import argparse
import json
from mlff.config import from_config
from ml_collections import config_dict
import pathlib
import yaml


def fine_tune_so3krates_sparse():
# Create the parser
    parser = argparse.ArgumentParser(description='Fine tune a SO3kratesSparse model.')
parser.add_argument(
'--config',
type=str,
required=True,
help='Path to the config file.'
)
parser.add_argument(
'--start_from_workdir',
type=str,
required=True,
help='Path to workdir from which fine tuning should be started.'
)

args = parser.parse_args()

    config = pathlib.Path(args.config).expanduser().resolve()
    if config.suffix == '.json':
        with open(config, mode='r') as fp:
            cfg = config_dict.ConfigDict(json.load(fp=fp))
    elif config.suffix == '.yaml':
        with open(config, mode='r') as fp:
            cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))
    else:
        raise ValueError(
            f'Config file must be a .json or .yaml file, found suffix {config.suffix}.'
        )

from_config.run_fine_tuning(cfg, start_from_workdir=args.start_from_workdir)


if __name__ == '__main__':
fine_tune_so3krates_sparse()
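
Once the package is installed (the `fine_tune_so3krates_sparse` console entry point is registered in `setup.py` below), the script would be invoked along these lines; the config file name and workdir path are placeholders:

    fine_tune_so3krates_sparse --config fine_tune.yaml --start_from_workdir /path/to/pretrained_workdir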
242 changes: 242 additions & 0 deletions mlff/config/from_config.py
@@ -366,3 +366,245 @@ def run_evaluation(
batch_max_num_graphs=config.training.batch_max_num_graphs,
write_batch_metrics_to=write_batch_metrics_to
)


def run_fine_tuning(
config: config_dict.ConfigDict,
start_from_workdir: str
):
"""Run training given a config.
Args:
config (): The config.
start_from_workdir (str): The workdir of the model that should be fine tuned.
Returns:
"""
workdir = Path(config.workdir).expanduser().resolve()
if workdir.exists():
raise ValueError(
f'Please specify new workdir for fine tuning. Workdir {workdir} already exists.'
)
workdir.mkdir(exist_ok=False)

start_from_workdir = Path(start_from_workdir).expanduser().resolve()
if not start_from_workdir.exists():
raise ValueError(
f'Trying to start fine tuning from {start_from_workdir} but directory does not exist.'
)
params = load_params_from_workdir(start_from_workdir)
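    # These pretrained parameters are handed to training_utils.fit() below as the warm start.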

data_filepath = config.data.filepath
data_filepath = Path(data_filepath).expanduser().resolve()

if data_filepath.suffix == '.npz':
loader = data.NpzDataLoaderSparse(input_file=data_filepath)
elif data_filepath.stem[:5].lower() == 'spice':
logging.mlff(f'Found SPICE dataset at {data_filepath}.')
        if data_filepath.suffix != '.hdf5':
            raise ValueError(
                f'Loader assumes that SPICE is in hdf5 format. Found {data_filepath.suffix} '
                f'as suffix.'
            )
loader = data.SpiceDataLoaderSparse(input_file=data_filepath)
else:
loader = data.AseDataLoaderSparse(input_file=data_filepath)

# Get the units of the data.
energy_unit = eval(config.data.energy_unit)
length_unit = eval(config.data.length_unit)
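    # Assumption: the unit entries are strings that eval to floats, e.g. '1.0' or an
    # ASE-style expression such as 'ase.units.kcal / ase.units.mol'.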

# Get the total number of data points.
num_data = loader.cardinality()
num_train = config.training.num_train
num_valid = config.training.num_valid

if num_train + num_valid > num_data:
raise ValueError(f"num_train + num_valid = {num_train + num_valid} exceeds the number of data points {num_data}"
f" in {data_filepath}.")

split_seed = config.data.split_seed
numpy_rng = np.random.RandomState(split_seed)

    # Choose the data points that are used for training (training + validation data).
all_indices = np.arange(num_data)
numpy_rng.shuffle(all_indices)
# We sort the indices after extracting them from the shuffled list, since we iteratively load the data with the
# data loader.
training_and_validation_indices = np.sort(all_indices[:(num_train + num_valid)])
test_indices = np.sort(all_indices[(num_train + num_valid):])

# Cutoff is in Angstrom, so we have to divide the cutoff by the length unit.
training_and_validation_data, data_stats = loader.load(
cutoff=config.model.cutoff / length_unit,
pick_idx=training_and_validation_indices
)
# Since the training and validation indices are sorted, the index i at the n-th entry in
# training_and_validation_indices corresponds to the n-th entry in training_and_validation_data which is the i-th
# data entry in the loaded data.
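    # Toy illustration with made-up numbers: for num_data = 6, a shuffle giving [4, 0, 3, 5, 1, 2]
    # and num_train + num_valid = 3, the sorted training_and_validation_indices are [0, 3, 4]; the
    # loader then yields original entries 0, 3 and 4 in that order, so position n in
    # training_and_validation_data holds original entry training_and_validation_indices[n].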
split_indices = np.arange(num_train + num_valid)
numpy_rng.shuffle(split_indices)
internal_train_indices = split_indices[:num_train]
internal_validation_indices = split_indices[num_train:]

training_data = [training_and_validation_data[i_train] for i_train in internal_train_indices]
validation_data = [training_and_validation_data[i_val] for i_val in internal_validation_indices]
del training_and_validation_data

assert len(internal_train_indices) == num_train
assert len(internal_validation_indices) == num_valid

if config.data.shift_mode == 'mean':
config.data.energy_shifts = config_dict.placeholder(dict)
energy_mean = data.transformations.calculate_energy_mean(training_data) * energy_unit
num_nodes = data.transformations.calculate_average_number_of_nodes(training_data)
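        # The mean energy per node is used as one uniform shift for every atomic number (0-118).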
energy_shifts = {str(a): float(energy_mean / num_nodes) for a in range(119)}
config.data.energy_shifts = energy_shifts
elif config.data.shift_mode == 'custom':
if config.data.energy_shifts is None:
            raise ValueError("For config.data.shift_mode == 'custom', config.data.energy_shifts must be given.")
else:
config.data.energy_shifts = {str(a): 0. for a in range(119)}

# If messages are normalized by the average number of neighbors, we need to calculate this quantity from the
# training data.
if config.model.message_normalization == 'avg_num_neighbors':
config.data.avg_num_neighbors = config_dict.placeholder(float)
avg_num_neighbors = data.transformations.calculate_average_number_of_neighbors(training_data)
config.data.avg_num_neighbors = np.array(avg_num_neighbors).item()

training_data = list(data.transformations.subtract_atomic_energy_shifts(
data.transformations.unit_conversion(
training_data,
energy_unit=energy_unit,
length_unit=length_unit
),
atomic_energy_shifts={int(k): v for (k, v) in config.data.energy_shifts.items()}
))

validation_data = list(data.transformations.subtract_atomic_energy_shifts(
data.transformations.unit_conversion(
validation_data,
energy_unit=energy_unit,
length_unit=length_unit
),
atomic_energy_shifts={int(k): v for (k, v) in config.data.energy_shifts.items()}
))

opt = make_optimizer_from_config(config)

    # TODO: One could load the model from the original workdir itself, but that would require either a dedicated
    #  fine-tuning config or silently ignoring the model section of the config file. For now the responsibility
    #  lies with the user, who has to define a model config that matches the pretrained parameters. A mismatch
    #  makes the code fail immediately, so it is directly visible to the user.
so3k = make_so3krates_sparse_from_config(config)

loss_fn = training_utils.make_loss_fn(
get_energy_and_force_fn_sparse(so3k),
weights=config.training.loss_weights
)

if config.training.batch_max_num_nodes is None:
assert config.training.batch_max_num_edges is None

batch_max_num_nodes = data_stats['max_num_of_nodes'] * (config.training.batch_max_num_graphs - 1) + 1
batch_max_num_edges = data_stats['max_num_of_edges'] * (config.training.batch_max_num_graphs - 1) + 1
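        # The '+ 1' above presumably reserves room for the single padding node/edge of the dummy
        # graph that jraph-style batching appends, so only (batch_max_num_graphs - 1) slots hold
        # real graphs.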

config.training.batch_max_num_nodes = batch_max_num_nodes
config.training.batch_max_num_edges = batch_max_num_edges

    # internal_*_indices only run over [0, num_train + num_valid). To get their original positions in the full
    # data set we look them up in training_and_validation_indices. Since training and validation data were loaded
    # as training_and_validation_data[internal_*_indices], we have to make sure that training_and_validation_indices
    # and training_and_validation_data reference the data in the same order. This is achieved by sorting the
    # indices as described above.
train_indices = training_and_validation_indices[internal_train_indices]
validation_indices = training_and_validation_indices[internal_validation_indices]
assert len(train_indices) == num_train
assert len(validation_indices) == num_valid
with open(workdir / 'data_splits.json', 'w') as fp:
j = dict(
training=train_indices.tolist(),
validation=validation_indices.tolist(),
test=test_indices.tolist()
)
json.dump(j, fp)

    with open(workdir / 'hyperparameters.json', 'w') as fp:
        json.dump(config.to_dict(), fp)

with open(workdir / "hyperparameters.yaml", "w") as yaml_file:
yaml.dump(config.to_dict(), yaml_file, default_flow_style=False)

wandb.init(config=config.to_dict(), **config.training.wandb_init_args)

logging.mlff(
f'Fine tuning model from {start_from_workdir} on {data_filepath}!'
)
training_utils.fit(
model=so3k,
optimizer=opt,
loss_fn=loss_fn,
graph_to_batch_fn=jraph_utils.graph_to_batch_fn,
batch_max_num_edges=config.training.batch_max_num_edges,
batch_max_num_nodes=config.training.batch_max_num_nodes,
batch_max_num_graphs=config.training.batch_max_num_graphs,
training_data=training_data,
validation_data=validation_data,
params=params,
ckpt_dir=workdir / 'checkpoints',
eval_every_num_steps=config.training.eval_every_num_steps,
allow_restart=config.training.allow_restart,
num_epochs=config.training.num_epochs,
training_seed=config.training.training_seed,
model_seed=config.training.model_seed,
log_gradient_values=config.training.log_gradient_values
)
    logging.mlff('Fine tuning has finished!')


def load_params_from_workdir(workdir):
"""Load parameters from workdir.
Args:
workdir (str): Path to `workdir`.
Returns:
PyTree of parameters.
Raises:
ValueError: Workdir does not have a checkpoint directory.
RuntimeError: Loaded parameters are None.
"""
load_path = Path(workdir).expanduser().resolve() / "checkpoints"
if not load_path.exists():
raise ValueError(
f'Trying to load parameters from {load_path} but path does not exist.'
)

loaded_mngr = checkpoint.CheckpointManager(
load_path,
{
"params": checkpoint.PyTreeCheckpointer(),
},
options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"),
)
mgr_state = loaded_mngr.restore(
loaded_mngr.latest_step(),
{
"params": checkpoint.PyTreeCheckpointer(),
})
params = mgr_state.get("params")

if params is None:
raise RuntimeError(
f'Parameters loaded from {load_path} are None.'
)

del loaded_mngr

return params
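
A minimal usage sketch for this helper; the workdir path is a placeholder and the parameter count is printed only for illustration:

    import jax

    params = load_params_from_workdir('/path/to/pretrained_workdir')
    num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
    print(f'Restored {num_params} parameters for fine tuning.')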
12 changes: 8 additions & 4 deletions mlff/utils/training_utils.py
Expand Up @@ -127,11 +127,12 @@ def fit(
batch_max_num_nodes,
batch_max_num_edges,
batch_max_num_graphs,
params=None,
num_epochs: int = 100,
ckpt_dir: str = None,
ckpt_manager_options: dict = None,
eval_every_num_steps: int = 1000,
allow_restart=False,
allow_restart: bool = False,
training_seed: int = 0,
model_seed: int = 0,
use_wandb: bool = True,
@@ -150,6 +151,8 @@
batch_max_num_nodes (int): Maximal number of nodes per batch.
batch_max_num_edges (int): Maximal number of edges per batch.
batch_max_num_graphs (int): Maximal number of graphs per batch.
params: Parameters to warm start from. If not given, parameters are either initialized randomly or
    loaded from ckpt_dir if a checkpoint already exists and `allow_restart=True`.
num_epochs (int): Number of training epochs.
ckpt_dir (str): Checkpoint path.
ckpt_manager_options (dict): Checkpoint manager options.
@@ -193,7 +196,6 @@
processed_nodes = 0
step = 0

params = None
opt_state = None
for epoch in range(num_epochs):
# Shuffle the training data.
@@ -215,8 +217,8 @@
# Training data is numpy arrays so we now transform them to jax.numpy arrays.
batch_training = jax.tree_map(jnp.array, batch_training)

# In the first step, initialize the parameters or load from existing checkpoint.
if step == 0:
# If params is None (in the first step, unless a warm start was given), initialize the parameters or load them from an existing checkpoint.
if params is None:
# Check if checkpoint already exists.
latest_step = ckpt_mngr.latest_step()
if latest_step is not None:
@@ -233,6 +235,8 @@
else:
params = model.init(jax_rng, batch_training)

# If the optimizer state is None (in the first step), initialize it from the parameter PyTree.
if opt_state is None:
opt_state = optimizer.init(params)

# Make sure parameters and opt_state are set.
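
Stripped of mlff specifics, the warm-start logic that `fit` now implements reduces to this self-contained sketch (toy parameters and a toy init function, purely illustrative):

    import jax
    import jax.numpy as jnp
    import optax

    def toy_init(rng):
        # Stand-in for model.init(jax_rng, batch_training).
        return {'w': jax.random.normal(rng, (3,))}

    pretrained = {'w': jnp.ones((3,))}  # pretend these came from load_params_from_workdir

    params = pretrained      # corresponds to calling fit(..., params=pretrained)
    if params is None:       # mirrors the `if params is None` branch in fit()
        params = toy_init(jax.random.PRNGKey(0))

    opt_state = optax.adam(1e-3).init(params)  # optimizer state built from the warm-start parameters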
1 change: 1 addition & 0 deletions setup.py
@@ -42,6 +42,7 @@
"trajectory_to_xyz=mlff.cAPI.mlff_postprocessing:trajectory_to_xyz",
"to_mlff_input=mlff.cAPI.mlff_input_processing:to_mlff_input",
"train_so3krates_sparse=mlff.CLI.run_training:train_so3krates_sparse",
"fine_tune_so3krates_sparse=mlff.CLI.run_fine_tuning:fine_tune_so3krates_sparse",
"evaluate_so3krates_sparse=mlff.CLI.run_evaluation:evaluate_so3krates_sparse",
"evaluate_so3krates_sparse_on=mlff.CLI.run_evaluation_on:evaluate_so3krates_sparse_on"
],