add ITPNet #26

Merged · 1 commit · Mar 15, 2024
28 changes: 28 additions & 0 deletions mlff/CLI/run_training_itp_net.py
@@ -0,0 +1,28 @@
import argparse
import json
from mlff.config import from_config
from ml_collections import config_dict
import pathlib
import yaml


def train_itp_net():
    # Create the parser
    parser = argparse.ArgumentParser(description='Train an ITPNet model.')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file.')

    args = parser.parse_args()

    config = pathlib.Path(args.config).expanduser().absolute().resolve()
    if config.suffix == '.json':
        with open(config, mode='r') as fp:
            cfg = config_dict.ConfigDict(json.load(fp=fp))
    elif config.suffix == '.yaml':
        with open(config, mode='r') as fp:
            cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))
    else:
        raise ValueError(f'Config file must be .json or .yaml, got {config.suffix!r}.')

    from_config.run_training(cfg, model='itp_net')


if __name__ == '__main__':
    train_itp_net()
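
For reference, the snippet below is the programmatic equivalent of this CLI; a minimal sketch, assuming a config file named config_itp_net.yaml (like the one added below) exists in the working directory:

import yaml
from ml_collections import config_dict
from mlff.config import from_config

# Load the YAML config into a ConfigDict, mirroring what train_itp_net() does.
with open('config_itp_net.yaml', mode='r') as fp:
    cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))

# Dispatch into the ITP branch added to run_training() in this PR.
from_config.run_training(cfg, model='itp_net')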
67 changes: 67 additions & 0 deletions mlff/config/config_itp_net.yaml
@@ -0,0 +1,67 @@
workdir: first_experiment_itp  # Working directory. Checkpoints and hyperparameters are saved there.
data:
  filepath: null  # Path to the data file. Either an ASE-digestible file or an .npz with appropriate column names is supported.
  energy_unit: eV  # Energy unit.
  length_unit: Angstrom  # Length unit.
  shift_mode: null  # Options are null, mean, custom.
  energy_shifts: null  # Energy shifts to subtract.
  split_seed: 0  # Seed used for splitting the data into training, validation and test.
model:
  num_features: 128  # Number of invariant features.
  radial_basis_fn: reciprocal_bernstein  # Radial basis function to use.
  num_radial_basis_fn: 32  # Number of radial basis functions.
  cutoff: 5.0  # Local cutoff to use.
  cutoff_fn: smooth_cutoff  # Cutoff function to use.
  filter_num_layers: 2  # Number of filter layers.
  filter_activation_fn: identity  # Activation function for the filter.
  mp_max_degree: 2
  mp_post_res_block: true
  mp_post_res_block_activation_fn: identity
  itp_num_features: 16
  itp_max_degree: 2
  itp_num_updates: 2
  itp_post_res_block: true
  itp_post_res_block_activation_fn: identity
  itp_connectivity: dense
  feature_collection_over_layers: last
  include_pseudotensors: false
  message_normalization: avg_num_neighbors  # How to normalize the message function. Options are (identity, sqrt_num_features, avg_num_neighbors).
  output_is_zero_at_init: true  # The output of the full network is zero at initialization.
  energy_regression_dim: 128  # Dimension to which final features are projected before atomic energies are calculated.
  energy_activation_fn: identity  # Activation function applied at the energy_regression_dim before atomic energies are calculated.
  energy_learn_atomic_type_scales: false
  energy_learn_atomic_type_shifts: false
  input_convention: positions  # Input convention.
optimizer:
  name: adam  # Name of the optimizer. See https://optax.readthedocs.io/en/latest/api.html#common-optimizers for available ones.
  learning_rate: 0.001  # Learning rate to use.
  learning_rate_schedule: exponential_decay  # Which learning rate schedule to use. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules for available ones.
  learning_rate_schedule_args:  # Arguments passed to the learning rate schedule. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules.
    decay_rate: 0.75
    transition_steps: 125000
  gradient_clipping: identity
  gradient_clipping_args: null
  num_of_nans_to_ignore: 0  # Number of repeated update/gradient steps that ignore NaNs before raising an error.
training:
  allow_restart: false  # Restarting from a checkpoint is allowed. This will overwrite existing checkpoints, so only use it if this is desired.
  num_epochs: 100  # Number of epochs.
  num_train: 950  # Number of training points to draw from data.filepath.
  num_valid: 50  # Number of validation points to draw from data.filepath.
  batch_max_num_nodes: null  # Maximal number of nodes per batch. Must be at least the maximal number of atoms + 1 in the data set.
  batch_max_num_edges: null  # Maximal number of edges per batch. Must be at least the maximal number of edges + 1 in the data set.
  # If batch_max_num_nodes and batch_max_num_edges are set to null, they are determined from batch_max_num_graphs.
  # If they are set to values, each batch contains as many molecular structures/graphs as possible such that none of
  # the three values batch_max_num_nodes, batch_max_num_edges and batch_max_num_graphs is exceeded.
  batch_max_num_graphs: 6  # Maximal number of graphs per batch.
  # Since one padding graph is involved, an effective batch size of 5 corresponds to batch_max_num_graphs = 6.
  eval_every_num_steps: 1000  # Number of gradient steps after which the metrics on the validation set are calculated.
  loss_weights:
    energy: 0.01  # Loss weight for the energy.
    forces: 0.99  # Loss weight for the forces.
  model_seed: 0  # Seed used for the initialization of the model parameters.
  training_seed: 0  # Seed used for shuffling the batches during training.
  log_gradient_values: false  # Log the norm of the gradients for each set of weights.
  wandb_init_args:  # Arguments to wandb.init(). See https://docs.wandb.ai/ref/python/init. The config itself is passed as config to wandb.init().
    name: first_training_run
    project: mlff
    group: null
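
The batching comments in the training section above can be made concrete with a little arithmetic. A minimal sketch; max_atoms and max_edges are hypothetical stand-ins for properties of the largest structure in data.filepath, not values from this PR:

# Sketch of the padding rules described in the config comments (not mlff's batching code).
max_atoms = 21    # hypothetical: most atoms in any single structure
max_edges = 210   # hypothetical: most edges in any single structure

batch_max_num_graphs = 6                          # one graph is reserved for padding,
effective_batch_size = batch_max_num_graphs - 1   # so 5 real structures fit per batch

batch_max_num_nodes = max_atoms + 1   # lower bound from the comments: max atoms + 1
batch_max_num_edges = max_edges + 1   # lower bound from the comments: max edges + 1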
56 changes: 52 additions & 4 deletions mlff/config/from_config.py
@@ -61,6 +61,46 @@ def make_so3krates_sparse_from_config(config: config_dict.ConfigDict = None):
)


+def make_itp_net_from_config(config: config_dict.ConfigDict):
+    """Make an iterated tensor product model from a config.
+
+    Args:
+        config (): The config.
+
+    Returns:
+        ITP flax model.
+    """
+
+    model_config = config.model
+
+    return nn.ITPNet(
+        num_features=model_config.num_features,
+        radial_basis_fn=model_config.radial_basis_fn,
+        num_radial_basis_fn=model_config.num_radial_basis_fn,
+        cutoff_fn=model_config.cutoff_fn,
+        filter_num_layers=model_config.filter_num_layers,
+        filter_activation_fn=model_config.filter_activation_fn,
+        mp_max_degree=model_config.mp_max_degree,
+        mp_post_res_block=model_config.mp_post_res_block,
+        mp_post_res_block_activation_fn=model_config.mp_post_res_block_activation_fn,
+        itp_max_degree=model_config.itp_max_degree,
+        itp_num_features=model_config.itp_num_features,
+        itp_post_res_block=model_config.itp_post_res_block,
+        itp_post_res_block_activation_fn=model_config.itp_post_res_block_activation_fn,
+        itp_connectivity=model_config.itp_connectivity,
+        feature_collection_over_layers=model_config.feature_collection_over_layers,
+        include_pseudotensors=model_config.include_pseudotensors,
+        message_normalization=config.model.message_normalization,
+        avg_num_neighbors=config.data.avg_num_neighbors if config.model.message_normalization == 'avg_num_neighbors' else None,
+        output_is_zero_at_init=model_config.output_is_zero_at_init,
+        input_convention=model_config.input_convention,
+        energy_regression_dim=model_config.energy_regression_dim,
+        energy_activation_fn=model_config.energy_activation_fn,
+        energy_learn_atomic_type_scales=model_config.energy_learn_atomic_type_scales,
+        energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts
+    )


def make_optimizer_from_config(config: config_dict.ConfigDict = None):
"""Make optax optimizer from config.

@@ -85,11 +125,12 @@ def make_optimizer_from_config(config: config_dict.ConfigDict = None):
)


-def run_training(config: config_dict.ConfigDict):
+def run_training(config: config_dict.ConfigDict, model: str = None):
    """Run training given a config.

    Args:
        config (): The config.
+        model (): The model to train. Defaults to SO3krates.

    Returns:

@@ -190,10 +231,17 @@ def run_training(config: config_dict.ConfigDict):
))

    opt = make_optimizer_from_config(config)
-    so3k = make_so3krates_sparse_from_config(config)
+    if model is None or model == 'so3krates':
+        net = make_so3krates_sparse_from_config(config)
+    elif model == 'itp_net':
+        net = make_itp_net_from_config(config)
+    else:
+        raise ValueError(
+            f'{model=} is not a valid model.'
+        )

    loss_fn = training_utils.make_loss_fn(
-        get_energy_and_force_fn_sparse(so3k),
+        get_energy_and_force_fn_sparse(net),
        weights=config.training.loss_weights
    )

@@ -238,7 +286,7 @@ def run_training(config: config_dict.ConfigDict):
    wandb.init(config=config.to_dict(), **config.training.wandb_init_args)
    logging.mlff('Training is starting!')
    training_utils.fit(
-        model=so3k,
+        model=net,
        optimizer=opt,
        loss_fn=loss_fn,
        graph_to_batch_fn=jraph_utils.graph_to_batch_fn,
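
run_training() wires config.training.loss_weights into training_utils.make_loss_fn above. As a rough sketch of how such weights combine — the exact loss form is not shown in this diff, so the squared-error terms below are an assumption:

import jax.numpy as jnp

# Hypothetical weighted sum of per-property errors; mlff's actual loss lives in
# training_utils.make_loss_fn and may differ in detail.
def weighted_loss(e_pred, e_ref, f_pred, f_ref, weights):
    energy_term = jnp.mean((e_pred - e_ref) ** 2)
    forces_term = jnp.mean((f_pred - f_ref) ** 2)
    return weights['energy'] * energy_term + weights['forces'] * forces_term

# Toy values; the weights are the ones from config_itp_net.yaml.
loss = weighted_loss(
    e_pred=jnp.array([1.0]), e_ref=jnp.array([0.9]),
    f_pred=jnp.zeros((3, 3)), f_ref=jnp.ones((3, 3)) * 0.1,
    weights={'energy': 0.01, 'forces': 0.99},
)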
5 changes: 3 additions & 2 deletions mlff/nn/__init__.py
@@ -1,7 +1,8 @@
from .representation import (So3krates,
                             So3kratACE,
                             SchNet,
-                             SO3kratesSparse)
+                             SO3kratesSparse,
+                             ITPNet)

from .stacknet import (get_observable_fn,
                       get_energy_force_stress_fn,
@@ -15,4 +16,4 @@
from .observable import (Energy,
                         ZBLRepulsion)

-from .embed import GeometryEmbedSparse
+from .embed import GeometryEmbedSparse, GeometryEmbedE3x
1 change: 1 addition & 0 deletions mlff/nn/embed/__init__.py
@@ -6,6 +6,7 @@

from .embed_sparse import (
    GeometryEmbedSparse,
+    GeometryEmbedE3x,
    AtomTypeEmbedSparse
)

59 changes: 59 additions & 0 deletions mlff/nn/embed/embed_sparse.py
@@ -5,6 +5,7 @@
from typing import (Any, Dict, Sequence)

import flax.linen as nn
+import e3x

from mlff.nn.base.sub_module import BaseSubModule
from mlff.masking.mask import safe_mask
@@ -14,6 +15,64 @@
from mlff.basis_function.spherical import init_sph_fn


+class GeometryEmbedE3x(BaseSubModule):
+    prop_keys: Dict
+    max_degree: int
+    radial_basis_fn: str
+    num_radial_basis_fn: int
+    cutoff_fn: str
+    cutoff: float
+    input_convention: str = 'positions'
+    module_name: str = 'geometry_embed_e3x'
+
+    def __call__(self, inputs, *args, **kwargs):
+        idx_i = inputs['idx_i']  # shape: (num_pairs)
+        idx_j = inputs['idx_j']  # shape: (num_pairs)
+        cell = inputs.get('cell')  # shape: (num_graphs, 3, 3)
+        cell_offsets = inputs.get('cell_offset')  # shape: (num_pairs, 3)
+
+        if self.input_convention == 'positions':
+            positions = inputs['positions']  # (N, 3)
+
+            # Calculate pairwise displacement vectors.
+            r_ij = jax.vmap(
+                lambda i, j: positions[j] - positions[i]
+            )(idx_i, idx_j)  # (num_pairs, 3)
+
+            # Apply minimal image convention if needed.
+            if cell is not None:
+                r_ij = add_cell_offsets_sparse(
+                    r_ij=r_ij,
+                    cell=cell,
+                    cell_offsets=cell_offsets
+                )  # shape: (num_pairs, 3)
+
+        # Here it is assumed that PBC (if present) have already been respected in the displacement calculation.
+        elif self.input_convention == 'displacements':
+            positions = None
+            r_ij = inputs['displacements']
+        else:
+            raise ValueError(f"{self.input_convention} is not a valid argument for `input_convention`.")
+
+        basis, cut = e3x.nn.basis(
+            r=r_ij,
+            max_degree=self.max_degree,
+            radial_fn=getattr(e3x.nn, self.radial_basis_fn),
+            num=self.num_radial_basis_fn,
+            cutoff_fn=partial(getattr(e3x.nn, self.cutoff_fn), cutoff=self.cutoff),
+            return_cutoff=True
+        )  # (num_pairs, 1, (max_degree+1)^2, num_radial_basis_fn), (num_pairs,)
+
+        geometric_data = {'positions': positions,
+                          'basis': basis,
+                          'r_ij': r_ij,
+                          'cut': cut,
+                          }
+
+        return geometric_data


class GeometryEmbedSparse(BaseSubModule):
    prop_keys: Dict
    degrees: Sequence[int]
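
Stepping back to the new GeometryEmbedE3x module above: a minimal sketch of constructing and calling it on a toy structure. All input values are made up, the hyperparameters mirror config_itp_net.yaml, and prop_keys is left empty here, which assumes __call__ does not read it:

import jax
import jax.numpy as jnp

# Toy inputs: three atoms, two directed pairs, no periodic cell.
inputs = {
    'positions': jnp.array([[0.0, 0.0, 0.0],
                            [0.0, 0.0, 1.0],
                            [0.0, 1.0, 0.0]]),
    'idx_i': jnp.array([0, 0]),
    'idx_j': jnp.array([1, 2]),
}

embed = GeometryEmbedE3x(
    prop_keys={},
    max_degree=2,
    radial_basis_fn='reciprocal_bernstein',
    cutoff_fn='smooth_cutoff',
    num_radial_basis_fn=32,
    cutoff=5.0,
)

params = embed.init(jax.random.PRNGKey(0), inputs)
out = embed.apply(params, inputs)  # dict with 'positions', 'basis', 'r_ij', 'cut'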
3 changes: 3 additions & 0 deletions mlff/nn/embed/h_register.py
@@ -8,6 +8,7 @@

from .embed_sparse import (
    GeometryEmbedSparse,
+    GeometryEmbedE3x,
    AtomTypeEmbedSparse
)

@@ -19,6 +20,8 @@ def get_embedding_module(name: str, h: Dict):
        return AtomTypeEmbedSparse(**h)
    elif name == 'geometry_embed':
        return GeometryEmbed(**h)
+    elif name == 'geometry_embed_e3x':
+        return GeometryEmbedE3x(**h)
    elif name == 'geometry_embed_sparse':
        return GeometryEmbedSparse(**h)
    elif name == 'one_hot_embed':
1 change: 1 addition & 0 deletions mlff/nn/layer/__init__.py
@@ -2,4 +2,5 @@
from .so3krates_layer import So3kratesLayer
from .so3kratace_layer import So3krataceLayer
from .so3krates_layer_sparse import SO3kratesLayerSparse
+from .itp_layer import ITPLayer
from .h_register import get_layer
3 changes: 3 additions & 0 deletions mlff/nn/layer/h_register.py
@@ -2,6 +2,7 @@
from .so3krates_layer import So3kratesLayer
from .so3kratace_layer import So3krataceLayer
from .schnet_layer import SchNetLayer
+from .itp_layer import ITPLayer


def get_layer(name: str, h: Dict):
@@ -11,6 +12,8 @@ def get_layer(name: str, h: Dict):
        return So3krataceLayer(**h)
    elif name == 'schnet_layer':
        return SchNetLayer(**h)
+    elif name == 'itp_layer':
+        return ITPLayer(**h)
    elif name == 'spookynet_layer':
        raise NotImplementedError('SpookyNet not implemented!')
        return SpookyNetLayer(**h)