diff --git a/mlff/CLI/run_training_itp_net.py b/mlff/CLI/run_training_itp_net.py
new file mode 100644
index 0000000..0365eeb
--- /dev/null
+++ b/mlff/CLI/run_training_itp_net.py
@@ -0,0 +1,30 @@
+import argparse
+import json
+from mlff.config import from_config
+from ml_collections import config_dict
+import pathlib
+import yaml
+
+
+def train_itp_net():
+    # Create the parser
+    parser = argparse.ArgumentParser(description='Train an ITPNet model.')
+    parser.add_argument('--config', type=str, required=True, help='Path to the config file.')
+
+    args = parser.parse_args()
+
+    config = pathlib.Path(args.config).expanduser().absolute().resolve()
+    if config.suffix == '.json':
+        with open(config, mode='r') as fp:
+            cfg = config_dict.ConfigDict(json.load(fp=fp))
+    elif config.suffix == '.yaml':
+        with open(config, mode='r') as fp:
+            cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))
+    else:
+        raise ValueError(f'Config file must be .json or .yaml, got {config.suffix}.')
+
+    from_config.run_training(cfg, model='itp_net')
+
+
+if __name__ == '__main__':
+    train_itp_net()
diff --git a/mlff/config/config_itp_net.yaml b/mlff/config/config_itp_net.yaml
new file mode 100644
index 0000000..9fc3112
--- /dev/null
+++ b/mlff/config/config_itp_net.yaml
@@ -0,0 +1,69 @@
+workdir: first_experiment_itp # Working directory. Checkpoints and hyperparameters are saved there.
+data:
+  filepath: null # Path to the data file. Either an ASE-digestible file or an .npz file with appropriate column names is supported.
+  energy_unit: eV # Energy unit.
+  length_unit: Angstrom # Length unit.
+  shift_mode: null # Options are null, mean, custom.
+  energy_shifts: null # Energy shifts to subtract.
+  split_seed: 0 # Seed used for splitting the data into training, validation and test.
+model:
+  num_features: 128 # Number of invariant features.
+  radial_basis_fn: reciprocal_bernstein # Radial basis function to use.
+  num_radial_basis_fn: 32 # Number of radial basis functions.
+  cutoff: 5.0 # Local cutoff to use.
+  cutoff_fn: smooth_cutoff # Cutoff function to use.
+  filter_num_layers: 2 # Number of filter layers.
+  filter_activation_fn: identity # Activation function for the filter.
+  mp_max_degree: 2
+  mp_post_res_block: true
+  mp_post_res_block_activation_fn: identity
+  itp_num_features: 16
+  itp_max_degree: 2
+  itp_num_updates: 2
+  itp_post_res_block: true
+  itp_post_res_block_activation_fn: identity
+  itp_connectivity: dense
+  feature_collection_over_layers: last
+  include_pseudotensors: false
+  message_normalization: avg_num_neighbors # How to normalize the message function. Options are (identity, sqrt_num_features, avg_num_neighbors).
+  output_is_zero_at_init: true # The output of the full network is zero at initialization.
+  energy_regression_dim: 128 # Dimension to which final features are projected, before atomic energies are calculated.
+  energy_activation_fn: identity # Activation function to use on the energy_regression_dim before atomic energies are calculated.
+  energy_learn_atomic_type_scales: false
+  energy_learn_atomic_type_shifts: false
+  input_convention: positions # Input convention.
+optimizer:
+  name: adam # Name of the optimizer. See https://optax.readthedocs.io/en/latest/api.html#common-optimizers for available ones.
+  learning_rate: 0.001 # Learning rate to use.
+  learning_rate_schedule: exponential_decay # Which learning rate schedule to use. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules for available ones.
+  learning_rate_schedule_args: # Arguments passed to the learning rate schedule. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules.
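+    # The keys below are passed as keyword arguments to the schedule named above;
+    # for optax.exponential_decay the relevant ones are decay_rate and transition_steps.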
+    decay_rate: 0.75
+    transition_steps: 125000
+  gradient_clipping: identity
+  gradient_clipping_args: null
+  num_of_nans_to_ignore: 0 # Number of repeated update/gradient steps that ignore NaNs before raising an error.
+training:
+  allow_restart: false # Restarting from a checkpoint is allowed. This will overwrite existing checkpoints, so only use it if this is desired.
+  num_epochs: 100 # Number of epochs.
+  num_train: 950 # Number of training points to draw from data.filepath.
+  num_valid: 50 # Number of validation points to draw from data.filepath.
+  batch_max_num_nodes: null # Maximal number of nodes per batch. Must be at least the maximal number of atoms + 1 in the data set.
+  batch_max_num_edges: null # Maximal number of edges per batch. Must be at least the maximal number of edges + 1 in the data set.
+  # If batch_max_num_nodes and batch_max_num_edges are set to null, they are determined from batch_max_num_graphs.
+  # If they are set to values, each batch will contain as many molecular structures/graphs such that none of the three values
+  # batch_max_num_nodes, batch_max_num_edges and batch_max_num_graphs is exceeded.
+  batch_max_num_graphs: 6 # Maximal number of graphs per batch.
+  # Since one padding graph is always present, batch_max_num_graphs: 6 corresponds to an effective batch size of 5.
+  eval_every_num_steps: 1000 # Number of gradient steps after which the metrics on the validation set are calculated.
+  loss_weights:
+    energy: 0.01 # Loss weight for the energy.
+    forces: 0.99 # Loss weight for the forces.
+  model_seed: 0 # Seed used for the initialization of the model parameters.
+  training_seed: 0 # Seed used for shuffling the batches during training.
+  log_gradient_values: false # Log the norm of the gradients for each set of weights.
+  wandb_init_args: # Arguments to wandb.init(). See https://docs.wandb.ai/ref/python/init. The config itself is passed as config to wandb.init().
+    name: first_training_run
+    project: mlff
+    group: null
diff --git a/mlff/config/from_config.py b/mlff/config/from_config.py
index 1306424..7776890 100644
--- a/mlff/config/from_config.py
+++ b/mlff/config/from_config.py
@@ -61,6 +61,48 @@ def make_so3krates_sparse_from_config(config: config_dict.ConfigDict = None):
     )
 
 
+def make_itp_net_from_config(config: config_dict.ConfigDict):
+    """Make an iterated tensor product (ITP) model from a config.
+
+    Args:
+        config (): The config.
+
+    Returns:
+        ITP flax model.
+    """
+
+    model_config = config.model
+
+    return nn.ITPNet(
+        num_features=model_config.num_features,
+        radial_basis_fn=model_config.radial_basis_fn,
+        num_radial_basis_fn=model_config.num_radial_basis_fn,
+        cutoff=model_config.cutoff,
+        cutoff_fn=model_config.cutoff_fn,
+        filter_num_layers=model_config.filter_num_layers,
+        filter_activation_fn=model_config.filter_activation_fn,
+        mp_max_degree=model_config.mp_max_degree,
+        mp_post_res_block=model_config.mp_post_res_block,
+        mp_post_res_block_activation_fn=model_config.mp_post_res_block_activation_fn,
+        itp_max_degree=model_config.itp_max_degree,
+        itp_num_features=model_config.itp_num_features,
+        itp_num_updates=model_config.itp_num_updates,
+        itp_post_res_block=model_config.itp_post_res_block,
+        itp_post_res_block_activation_fn=model_config.itp_post_res_block_activation_fn,
+        itp_connectivity=model_config.itp_connectivity,
+        feature_collection_over_layers=model_config.feature_collection_over_layers,
+        include_pseudotensors=model_config.include_pseudotensors,
+        message_normalization=model_config.message_normalization,
+        avg_num_neighbors=config.data.avg_num_neighbors if model_config.message_normalization == 'avg_num_neighbors' else None,
+        output_is_zero_at_init=model_config.output_is_zero_at_init,
+        input_convention=model_config.input_convention,
+        energy_regression_dim=model_config.energy_regression_dim,
+        energy_activation_fn=model_config.energy_activation_fn,
+        energy_learn_atomic_type_scales=model_config.energy_learn_atomic_type_scales,
+        energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts
+    )
+
+
 def make_optimizer_from_config(config: config_dict.ConfigDict = None):
     """Make optax optimizer from config.
 
@@ -85,11 +127,12 @@ def make_optimizer_from_config(config: config_dict.ConfigDict = None):
     )
 
 
-def run_training(config: config_dict.ConfigDict):
+def run_training(config: config_dict.ConfigDict, model: str = None):
     """Run training given a config.
 
     Args:
        config (): The config.
+        model (): The model to train. Either 'so3krates' or 'itp_net'. Defaults to SO3krates.
 
     Returns:
 
@@ -190,10 +233,17 @@ def run_training(config: config_dict.ConfigDict):
     ))
 
     opt = make_optimizer_from_config(config)
-    so3k = make_so3krates_sparse_from_config(config)
+    if model is None or model == 'so3krates':
+        net = make_so3krates_sparse_from_config(config)
+    elif model == 'itp_net':
+        net = make_itp_net_from_config(config)
+    else:
+        raise ValueError(
+            f'{model=} is not a valid model.'
+        )
 
     loss_fn = training_utils.make_loss_fn(
-        get_energy_and_force_fn_sparse(so3k),
+        get_energy_and_force_fn_sparse(net),
         weights=config.training.loss_weights
     )
 
@@ -238,7 +288,7 @@ def run_training(config: config_dict.ConfigDict):
     wandb.init(config=config.to_dict(), **config.training.wandb_init_args)
     logging.mlff('Training is starting!')
     training_utils.fit(
-        model=so3k,
+        model=net,
         optimizer=opt,
         loss_fn=loss_fn,
         graph_to_batch_fn=jraph_utils.graph_to_batch_fn,
diff --git a/mlff/nn/__init__.py b/mlff/nn/__init__.py
index 07576f9..b9bfca7 100644
--- a/mlff/nn/__init__.py
+++ b/mlff/nn/__init__.py
@@ -1,7 +1,8 @@
 from .representation import (So3krates,
                              So3kratACE,
                              SchNet,
-                             SO3kratesSparse)
+                             SO3kratesSparse,
+                             ITPNet)
 
 from .stacknet import (get_observable_fn,
                        get_energy_force_stress_fn,
@@ -15,4 +16,4 @@ from .observable import (Energy,
                          ZBLRepulsion)
 
-from .embed import GeometryEmbedSparse
+from .embed import GeometryEmbedSparse, GeometryEmbedE3x
diff --git a/mlff/nn/embed/__init__.py b/mlff/nn/embed/__init__.py
index 14762f8..74f2d43 100644
--- a/mlff/nn/embed/__init__.py
+++ b/mlff/nn/embed/__init__.py
@@ -6,6 +6,7 @@
 from .embed_sparse import (
     GeometryEmbedSparse,
+    GeometryEmbedE3x,
     AtomTypeEmbedSparse
 )
 
diff --git a/mlff/nn/embed/embed_sparse.py b/mlff/nn/embed/embed_sparse.py
index 5363245..d3a1898 100644
--- a/mlff/nn/embed/embed_sparse.py
+++ b/mlff/nn/embed/embed_sparse.py
@@ -5,6 +5,8 @@ from typing import (Any, Dict, Sequence)
 
 import flax.linen as nn
+import e3x
+from functools import partial
 
 from mlff.nn.base.sub_module import BaseSubModule
 from mlff.masking.mask import safe_mask
@@ -14,6 +16,64 @@ from mlff.basis_function.spherical import init_sph_fn
 
 
+class GeometryEmbedE3x(BaseSubModule):
+    prop_keys: Dict
+    max_degree: int
+    radial_basis_fn: str
+    num_radial_basis_fn: int
+    cutoff_fn: str
+    cutoff: float
+    input_convention: str = 'positions'
+    module_name: str = 'geometry_embed_e3x'
+
+    def __call__(self, inputs, *args, **kwargs):
+
+        idx_i = inputs['idx_i']  # shape: (num_pairs)
+        idx_j = inputs['idx_j']  # shape: (num_pairs)
+        cell = inputs.get('cell')  # shape: (num_graphs, 3, 3)
+        cell_offsets = inputs.get('cell_offset')  # shape: (num_pairs, 3)
+
+        if self.input_convention == 'positions':
+            positions = inputs['positions']  # (N, 3)
+
+            # Calculate pairwise distance vectors.
+            r_ij = jax.vmap(
+                lambda i, j: positions[j] - positions[i]
+            )(idx_i, idx_j)  # (num_pairs, 3)
+
+            # Apply minimal image convention if needed.
+            if cell is not None:
+                r_ij = add_cell_offsets_sparse(
+                    r_ij=r_ij,
+                    cell=cell,
+                    cell_offsets=cell_offsets
+                )  # shape: (num_pairs, 3)
+
+        # For the 'displacements' convention it is assumed that PBCs (if present) have already been respected in the displacement calculation.
+        elif self.input_convention == 'displacements':
+            positions = None
+            r_ij = inputs['displacements']
+        else:
+            raise ValueError(f"{self.input_convention} is not a valid argument for `input_convention`.")
+
+        basis, cut = e3x.nn.basis(
+            r=r_ij,
+            max_degree=self.max_degree,
+            radial_fn=getattr(e3x.nn, self.radial_basis_fn),
+            num=self.num_radial_basis_fn,
+            cutoff_fn=partial(getattr(e3x.nn, self.cutoff_fn), cutoff=self.cutoff),
+            return_cutoff=True
+        )  # (num_pairs, 1, (max_degree+1)^2, num_radial_basis_fn), (num_pairs,)
+
+        geometric_data = {'positions': positions,
+                          'basis': basis,
+                          'r_ij': r_ij,
+                          'cut': cut,
+                          }
+
+        return geometric_data
+
+
 class GeometryEmbedSparse(BaseSubModule):
     prop_keys: Dict
     degrees: Sequence[int]
diff --git a/mlff/nn/embed/h_register.py b/mlff/nn/embed/h_register.py
index 658002f..6713095 100644
--- a/mlff/nn/embed/h_register.py
+++ b/mlff/nn/embed/h_register.py
@@ -8,6 +8,7 @@
 from .embed_sparse import (
     GeometryEmbedSparse,
+    GeometryEmbedE3x,
     AtomTypeEmbedSparse
 )
 
@@ -19,6 +20,8 @@ def get_embedding_module(name: str, h: Dict):
         return AtomTypeEmbedSparse(**h)
     elif name == 'geometry_embed':
         return GeometryEmbed(**h)
+    elif name == 'geometry_embed_e3x':
+        return GeometryEmbedE3x(**h)
     elif name == 'geometry_embed_sparse':
         return GeometryEmbedSparse(**h)
     elif name == 'one_hot_embed':
diff --git a/mlff/nn/layer/__init__.py b/mlff/nn/layer/__init__.py
index ffb1836..06639a4 100644
--- a/mlff/nn/layer/__init__.py
+++ b/mlff/nn/layer/__init__.py
@@ -2,4 +2,5 @@
 from .so3krates_layer import So3kratesLayer
 from .so3kratace_layer import So3krataceLayer
 from .so3krates_layer_sparse import SO3kratesLayerSparse
+from .itp_layer import ITPLayer
 from .h_register import get_layer
diff --git a/mlff/nn/layer/h_register.py b/mlff/nn/layer/h_register.py
index 461eb6d..eaad8f0 100644
--- a/mlff/nn/layer/h_register.py
+++ b/mlff/nn/layer/h_register.py
@@ -2,6 +2,7 @@
 from .so3krates_layer import So3kratesLayer
 from .so3kratace_layer import So3krataceLayer
 from .schnet_layer import SchNetLayer
+from .itp_layer import ITPLayer
 
 
 def get_layer(name: str, h: Dict):
@@ -11,6 +12,8 @@ def get_layer(name: str, h: Dict):
         return So3krataceLayer(**h)
     elif name == 'schnet_layer':
         return SchNetLayer(**h)
+    elif name == 'itp_layer':
+        return ITPLayer(**h)
     elif name == 'spookynet_layer':
         raise NotImplementedError('SpookyNet not implemented!')
         return SpookyNetLayer(**h)
diff --git a/mlff/nn/layer/itp_layer.py b/mlff/nn/layer/itp_layer.py
new file mode 100644
index 0000000..39a9522
--- /dev/null
+++ b/mlff/nn/layer/itp_layer.py
@@ -0,0 +1,215 @@
+import jax.numpy as jnp
+import flax.linen as nn
+import jax
+
+import e3x
+from functools import partial
+from typing import Optional, Sequence
+
+from mlff.nn.base.sub_module import BaseSubModule
+from mlff.nn.layer.utils import Residual
+
+
+def get_activation_fn(name: str):
+    return getattr(e3x.nn, name) if name != 'identity' else lambda u: u
+
+
+def aggregate_from_features(features: Sequence, aggregation: str):
+    if aggregation == 'last':
+        return features[-1]
+    elif aggregation == 'concatenation':
+        return jnp.concatenate(features, axis=-1)
+    else:
+        raise ValueError(f'{aggregation} not a valid aggregation for features.')
+
+
+def aggregation_from_connectivity(connectivity: str):
+    if connectivity == 'independent':
+        return 'last'
+    elif connectivity == 'skip':
+        return 'last'
+    elif connectivity == 'dense':
+        return 'concatenation'
+    else:
+        raise ValueError(f'{connectivity} not a valid connectivity pattern.')
+
+
+class ITPLayer(BaseSubModule):
+    """Message passing step, followed by multiple atom-wise iterated tensor products.
+
+    """
+
+    filter_num_layers: int = 1
+    filter_activation_fn: str = 'identity'
+
+    mp_max_degree: Optional[int] = None
+    mp_post_res_block: bool = False
+    mp_post_res_block_activation_fn: str = 'identity'
+
+    itp_max_degree: Optional[int] = None
+    itp_num_features: Optional[int] = None
+    itp_num_updates: int = 1
+    itp_post_res_block: bool = False
+    itp_post_res_block_activation_fn: str = 'identity'
+    itp_connectivity: str = 'skip'  # dense, independent
+
+    message_normalization: Optional[str] = None  # avg_num_neighbors
+    avg_num_neighbors: Optional[float] = None
+
+    feature_collection_over_layers: str = 'last'  # summation
+
+    include_pseudotensors: bool = False
+    module_name: str = 'itp_layer'
+
+    def setup(self):
+        if self.itp_connectivity == 'dense':
+            if self.itp_max_degree is not None:
+                if self.itp_max_degree != self.mp_max_degree:
+                    raise ValueError(
+                        f'For {self.itp_connectivity=} the maximal degree of the tensor products must be equal '
+                        f'to the maximal degree in message passing, but {self.itp_max_degree=} != {self.mp_max_degree=}.'
+                    )
+        if self.message_normalization == 'avg_num_neighbors':
+            if self.avg_num_neighbors is None:
+                raise ValueError(
+                    f'For {self.message_normalization=} the average number of neighbors is required, but it is '
+                    f'{self.avg_num_neighbors}.'
+                )
+
+    @nn.compact
+    def __call__(self,
+                 x: jnp.ndarray,
+                 basis: jnp.ndarray,
+                 cut: jnp.ndarray,
+                 idx_i: jnp.ndarray,
+                 idx_j: jnp.ndarray,
+                 *args,
+                 **kwargs):
+        """
+
+        Args:
+            x (Array): Atomic features, shape (N, num_features) in the first layer.
+            basis (Array): Pairwise basis functions, shape (num_pairs, 1, (max_degree + 1)^2, num_radial_basis_fn).
+            cut (Array): Pairwise cutoff values, shape (num_pairs,).
+            idx_i (Array): Index of the central atom per pair, shape (num_pairs,).
+            idx_j (Array): Index of the neighboring atom per pair, shape (num_pairs,).
+            *args ():
+            **kwargs ():
+
+        Returns:
+            Dictionary with the updated atomic features under key 'x'.
+
+        """
+        num_features = x.shape[-1]
+        features = []
+        # In the first layer x has shape (N, num_features)
+        if x.ndim == 2:
+            x = x[:, None, None, :]  # (N, 1, 1, num_features)
+
+        # One layer is applied inside e3x.nn.MessagePass to align the basis and feature dimensions,
+        # so only filter_num_layers - 1 additional filter layers are applied here.
+        for _ in range(self.filter_num_layers - 1):
+            sigma = get_activation_fn(self.filter_activation_fn)
+            basis = sigma(
+                e3x.nn.Dense(
+                    features=num_features
+                )(
+                    basis
+                )
+            )
+
+        y = e3x.nn.MessagePass(
+            include_pseudotensors=self.include_pseudotensors,
+            max_degree=self.mp_max_degree
+        )(
+            inputs=x,
+            basis=basis,
+            src_idx=idx_i,
+            dst_idx=idx_j,
+            num_segments=len(x)
+        )
+
+        if self.message_normalization == 'avg_num_neighbors':
+            # Normalize the aggregated messages by the square root of the average number of neighbors.
+            y = jnp.divide(y,
+                           jnp.sqrt(jnp.asarray(self.avg_num_neighbors, dtype=y.dtype))
+                           )
+
+        # Dense layer on the outer part of the skip connection.
+        z = e3x.nn.Dense(features=num_features)(x)  # (N, 1 or 2, (max_degree + 1)^2, num_features)
+
+        # Skip connection around the message pass.
+        y = e3x.nn.add(y + z)  # (N, 1 or 2, (max_degree + 1)^2, num_features)
+
+        # Residual block.
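+        # (`Residual` stacks Dense layers and zero-initializes the last kernel, so the
+        # block acts as the identity at initialization; see mlff/nn/layer/utils.py.)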
+        if self.mp_post_res_block:
+            y = Residual(
+                activation_fn=get_activation_fn(self.mp_post_res_block_activation_fn)
+            )(y)
+
+        if self.itp_num_features is not None:
+            y = e3x.nn.Dense(
+                features=self.itp_num_features
+            )(y)
+
+        features.append(y)
+
+        for i in range(self.itp_num_updates):
+            aggregation = aggregation_from_connectivity(self.itp_connectivity)
+            x_pre_itp = aggregate_from_features(features, aggregation=aggregation)
+
+            x_itp = e3x.nn.TensorDense(
+                include_pseudotensors=False if i == self.itp_num_updates - 1 else self.include_pseudotensors,
+                max_degree=0 if i == self.itp_num_updates - 1 else self.itp_max_degree
+            )(x_pre_itp)
+
+            if self.itp_post_res_block:
+                x_itp = Residual(
+                    activation_fn=get_activation_fn(self.itp_post_res_block_activation_fn)
+                )(x_itp)
+
+            if self.itp_connectivity == 'skip':
+                if i == self.itp_num_updates - 1:
+                    x_pre_itp = e3x.nn.change_max_degree_or_type(
+                        x_pre_itp,
+                        max_degree=0,
+                        include_pseudotensors=False
+                    )
+
+                x_itp = e3x.nn.add(x_pre_itp, x_itp)
+
+            features.append(x_itp)
+
+        if self.feature_collection_over_layers == 'last':
+            x_final = features[-1]
+            # (N, 1, 1, num_itp_features * (num_itp_updates + 1)) for dense connectivity
+            # (N, 1, 1, num_itp_features) for skip and independent connectivity
+
+        elif self.feature_collection_over_layers == 'summation':
+            x_final = e3x.nn.Dense(
+                num_features,
+                use_bias=False
+            )(x)  # The input x is not part of `features`, so add its contribution by hand.
+            for x_f in features:
+                x_final += e3x.nn.Dense(
+                    num_features,
+                    use_bias=False
+                )(
+                    e3x.nn.change_max_degree_or_type(
+                        x_f,
+                        max_degree=0,
+                        include_pseudotensors=False
+                    )
+                )  # (N, 1, 1, num_features)
+        else:
+            raise ValueError(
+                f'{self.feature_collection_over_layers} not a valid argument for `feature_collection_over_layers`.'
+            )
+
+        x_final = x_final.squeeze(1).squeeze(1)  # (N, num_final_features)
+
+        return dict(
+            x=x_final
+        )
diff --git a/mlff/nn/layer/utils.py b/mlff/nn/layer/utils.py
new file mode 100644
index 0000000..df9a617
--- /dev/null
+++ b/mlff/nn/layer/utils.py
@@ -0,0 +1,82 @@
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+import e3x
+
+from typing import Any, Callable, Optional, Sequence
+
+
+class Residual(nn.Module):
+    """Residual block."""
+
+    num_blocks: int = 2
+    activation_fn: Callable[..., Any] = e3x.nn.activations.silu
+    use_bias: bool = True
+    kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_normal()
+    kernel_init_last_block: Callable[..., Any] = jax.nn.initializers.zeros
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, inputs):
+        feat = inputs.shape[-1]
+        x = inputs
+
+        for i in range(self.num_blocks - 1):
+            x = self.activation_fn(x)
+            x = e3x.nn.modules.Dense(
+                feat,
+                use_bias=self.use_bias,
+                dtype=self.dtype,
+                kernel_init=self.kernel_init,
+                param_dtype=self.param_dtype,
+                name=f"dense_{i}",
+            )(x)
+
+        # The last block is zero-initialized, so the residual block is the identity at initialization.
+        x = e3x.nn.modules.Dense(
+            feat,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
+            kernel_init=self.kernel_init_last_block,
+            name=f"dense_{self.num_blocks - 1}",
+        )(x)
+
+        return e3x.nn.add(inputs + x)
+
+
+class ResidualMLP(nn.Module):
+    """Residual MLP."""
+
+    num_residuals: int = 1
+    num_blocks_per_residual: int = 2
+    use_bias: bool = True
+    activation_fn: Callable[..., Any] = e3x.nn.activations.silu
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, inputs):
+        feat = inputs.shape[-1]
+        x = inputs
+        for i in range(self.num_residuals):
+            x = Residual(
+                self.num_blocks_per_residual,
+                self.activation_fn,
+                use_bias=self.use_bias,
+                dtype=self.dtype,
+                param_dtype=self.param_dtype,
+                name=f"residual_{i}",
+            )(x)
+
+        x = self.activation_fn(x)
+        x = e3x.nn.modules.Dense(
+            feat,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
+            name="dense_0",
+        )(x)
+
+        return x
diff --git a/mlff/nn/representation/__init__.py b/mlff/nn/representation/__init__.py
index 2894de7..4dc0b62 100644
--- a/mlff/nn/representation/__init__.py
+++ b/mlff/nn/representation/__init__.py
@@ -2,3 +2,4 @@
 from .schnet import init_schnet as SchNet
 from .so3kratace import init_so3kratace as So3kratACE
 from .so3krates_sparse import init_so3krates_sparse as SO3kratesSparse
+from .itp_net import init_itp_net as ITPNet
diff --git a/mlff/nn/representation/itp_net.py b/mlff/nn/representation/itp_net.py
new file mode 100644
index 0000000..0e55819
--- /dev/null
+++ b/mlff/nn/representation/itp_net.py
@@ -0,0 +1,87 @@
+import flax.linen as nn
+import jax
+from mlff.nn.stacknet import StackNetSparse
+from mlff.nn.embed import GeometryEmbedE3x, AtomTypeEmbedSparse
+from mlff.nn.layer import ITPLayer
+from mlff.nn.observable import EnergySparse
+from typing import Optional, Sequence
+
+
+def init_itp_net(
+        num_features: int = 32,
+        radial_basis_fn: str = 'reciprocal_bernstein',
+        num_radial_basis_fn: int = 16,
+        cutoff_fn: str = 'smooth_cutoff',
+        cutoff: float = 5.,
+        filter_num_layers: int = 1,
+        filter_activation_fn: str = 'identity',
+        mp_max_degree: int = 2,
+        mp_post_res_block: bool = True,
+        mp_post_res_block_activation_fn: str = 'identity',
+        itp_max_degree: int = 2,
+        itp_num_features: int = 32,
+        itp_num_updates: int = 3,
+        itp_post_res_block: bool = True,
+        itp_post_res_block_activation_fn: str = 'identity',
+        itp_connectivity: str = 'dense',
+        message_normalization: Optional[str] = None,
+        avg_num_neighbors: Optional[float] = None,
+        feature_collection_over_layers: str = 'last',  # 'last' or 'summation'
+        include_pseudotensors: bool = False,
+        output_is_zero_at_init: bool = True,
+        energy_regression_dim: int = 128,
+        energy_activation_fn: str = 'identity',
+        energy_learn_atomic_type_scales: bool = False,
+        energy_learn_atomic_type_shifts: bool = False,
+        input_convention: str = 'positions'
+):
+    atom_type_embed = AtomTypeEmbedSparse(
+        num_features=num_features,
+        prop_keys=None
+    )
+    geometry_embed = GeometryEmbedE3x(
+        max_degree=mp_max_degree,
+        radial_basis_fn=radial_basis_fn,
+        num_radial_basis_fn=num_radial_basis_fn,
+        cutoff_fn=cutoff_fn,
+        cutoff=cutoff,
+        input_convention=input_convention,
+        prop_keys=None
+    )
+
+    layers = [ITPLayer(
+        mp_max_degree=mp_max_degree,
+        filter_num_layers=filter_num_layers,
+        filter_activation_fn=filter_activation_fn,
+        mp_post_res_block=mp_post_res_block,
+        mp_post_res_block_activation_fn=mp_post_res_block_activation_fn,
+        itp_max_degree=itp_max_degree,
+        itp_num_features=itp_num_features,
+        itp_num_updates=itp_num_updates,
+        itp_post_res_block=itp_post_res_block,
+        itp_post_res_block_activation_fn=itp_post_res_block_activation_fn,
+        itp_connectivity=itp_connectivity,
+        message_normalization=message_normalization,
+        avg_num_neighbors=avg_num_neighbors,
+        feature_collection_over_layers=feature_collection_over_layers,
+        include_pseudotensors=include_pseudotensors
+    )]
+
+    energy = EnergySparse(
+        prop_keys=None,
+        output_is_zero_at_init=output_is_zero_at_init,
+        regression_dim=energy_regression_dim,
+        activation_fn=getattr(
+            nn.activation, energy_activation_fn
+        ) if energy_activation_fn != 'identity' else lambda u: u,
+        learn_atomic_type_scales=energy_learn_atomic_type_scales,
+        learn_atomic_type_shifts=energy_learn_atomic_type_shifts,
+    )
+
+    return StackNetSparse(
+        geometry_embeddings=[geometry_embed],
+        feature_embeddings=[atom_type_embed],
+        layers=layers,
+        observables=[energy],
+        prop_keys=None
+    )
diff --git a/setup.py b/setup.py
index d30bf15..621b647 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,7 @@
         "numpy",
         "clu",
         # "jax == 0.4.8",
+        "e3x",
         "flax",
         "jaxopt",
         "jraph",
@@ -42,6 +43,7 @@
             "trajectory_to_xyz=mlff.cAPI.mlff_postprocessing:trajectory_to_xyz",
             "to_mlff_input=mlff.cAPI.mlff_input_processing:to_mlff_input",
             "train_so3krates_sparse=mlff.CLI.run_training:train_so3krates_sparse",
+            "train_itp_net=mlff.CLI.run_training_itp_net:train_itp_net",
             "fine_tune_so3krates_sparse=mlff.CLI.run_fine_tuning:fine_tune_so3krates_sparse",
             "evaluate_so3krates_sparse=mlff.CLI.run_evaluation:evaluate_so3krates_sparse",
             "evaluate_so3krates_sparse_on=mlff.CLI.run_evaluation_on:evaluate_so3krates_sparse_on"
diff --git a/tests/test_itp_layer.py b/tests/test_itp_layer.py
new file mode 100644
index 0000000..d2b4e27
--- /dev/null
+++ b/tests/test_itp_layer.py
@@ -0,0 +1,355 @@
+import pytest
+
+import jax
+import jax.numpy as jnp
+
+import numpy.testing as npt
+
+from mlff.nn.layer import ITPLayer
+from mlff.nn import GeometryEmbedE3x
+
+rng1, rng2 = jax.random.split(jax.random.PRNGKey(0), 2)
+
+num_nodes = 7
+num_edges = 10
+
+num_features = 31
+ipt_num_features = 17
+ipt_num_updates = 3
+
+x = jnp.ones((num_nodes, num_features))
+basis = jnp.ones((num_edges, 1, 9, 15))
+cut = jnp.ones((num_edges,))
+idx_i = jax.random.randint(rng1, shape=(num_edges,), minval=0, maxval=num_nodes)
+idx_j = jax.random.randint(rng2, shape=(num_edges,), minval=0, maxval=num_nodes)
+
+
+@pytest.mark.parametrize("ipt_connectivity", ['dense', 'skip', 'independent'])
+@pytest.mark.parametrize("feature_collection_over_layers", ['last', 'summation'])
+def test_init(ipt_connectivity, feature_collection_over_layers):
+    layer = ITPLayer(
+        mp_max_degree=2,
+        filter_num_layers=2,
+        filter_activation_fn='silu',
+        mp_post_res_block=True,
+        mp_post_res_block_activation_fn='silu',
+        itp_max_degree=2,
+        itp_num_features=ipt_num_features,
+        itp_num_updates=ipt_num_updates,
+        itp_post_res_block=True,
+        itp_post_res_block_activation_fn='identity',
+        itp_connectivity=ipt_connectivity,
+        feature_collection_over_layers=feature_collection_over_layers,
+        include_pseudotensors=False
+    )
+
+    _ = layer.init(
+        jax.random.PRNGKey(0),
+        x=x,
+        basis=basis,
+        cut=cut,
+        idx_i=idx_i,
+        idx_j=idx_j,
+    )
+
+
+@pytest.mark.parametrize("ipt_connectivity", ['dense', 'skip', 'independent'])
+@pytest.mark.parametrize("feature_collection_over_layers", ['last', 'summation'])
+@pytest.mark.parametrize("_ipt_num_features", [None, 13])
+def test_apply(ipt_connectivity, feature_collection_over_layers, _ipt_num_features):
+    layer = ITPLayer(
+        mp_max_degree=2,
+        filter_num_layers=2,
+        filter_activation_fn='silu',
+        mp_post_res_block=True,
+        mp_post_res_block_activation_fn='silu',
+        itp_max_degree=2,
+        itp_num_features=_ipt_num_features,
+        itp_num_updates=ipt_num_updates,
+        itp_post_res_block=True,
+        itp_post_res_block_activation_fn='identity',
+        itp_connectivity=ipt_connectivity,
+        feature_collection_over_layers=feature_collection_over_layers,
+        include_pseudotensors=False
+    )
+
+    params = layer.init(
+        jax.random.PRNGKey(0),
+        x=x,
+        basis=basis,
+        cut=cut,
+        idx_i=idx_i,
+        idx_j=idx_j,
+    )
+
+    output = layer.apply(
+        params,
+        x=x,
+        basis=basis,
+        cut=cut,
+        idx_i=idx_i,
+ idx_j=idx_j, + ) + if feature_collection_over_layers == 'summation': + npt.assert_equal(output.get('x').shape, x.shape) + elif feature_collection_over_layers == 'last': + if ipt_connectivity == 'dense': + if _ipt_num_features is None: + npt.assert_equal( + output.get('x').shape, + (len(x), num_features + num_features * ipt_num_updates) + ) + else: + npt.assert_equal( + output.get('x').shape, + (len(x), _ipt_num_features + _ipt_num_features * ipt_num_updates) + ) + elif ipt_connectivity == 'skip': + if _ipt_num_features is None: + npt.assert_equal(output.get('x').shape, (len(x), num_features)) + else: + npt.assert_equal(output.get('x').shape, (len(x), _ipt_num_features)) + elif ipt_connectivity == 'independent': + if _ipt_num_features is None: + npt.assert_equal(output.get('x').shape, (len(x), num_features)) + else: + npt.assert_equal(output.get('x').shape, (len(x), _ipt_num_features)) + else: + raise RuntimeError + else: + raise RuntimeError + +# +# +# def test_translation_invariance(): +# x = jnp.ones((3, 64)) +# positions = jnp.array([ +# [0., 0., 0.], +# [0., -2., 0.], +# [1., 0.5, 0.], +# ]) +# idx_i = jnp.array([0, 0, 1, 2]) +# idx_j = jnp.array([1, 2, 0, 0]) +# +# geometry_embed = GeometryEmbedSparse(degrees=[1, 2], +# radial_basis_fn='bernstein', +# num_radial_basis_fn=16, +# cutoff_fn='exponential', +# cutoff=2.5, +# input_convention='positions', +# prop_keys=None) +# +# geometry_embed_inputs = dict( +# positions=positions, +# idx_i=idx_i, +# idx_j=idx_j, +# cell=None, +# cell_offset=None +# ) +# +# geometry_embed_inputs_translated = dict( +# positions=positions + jnp.array([2.5, 1.0, -0.7])[None], +# idx_i=idx_i, +# idx_j=idx_j, +# cell=None, +# cell_offset=None +# ) +# +# params = geometry_embed.init( +# jax.random.PRNGKey(0), +# geometry_embed_inputs +# ) +# +# geometry_embed_output = geometry_embed.apply( +# params, +# geometry_embed_inputs +# ) +# +# geometry_embed_output_translated = geometry_embed.apply( +# params, +# geometry_embed_inputs_translated +# ) +# +# so3krates_layer = SO3kratesLayerSparse( +# degrees=[1, 2], +# use_spherical_filter=True, +# num_heads=2, +# num_features_head=16, +# qk_non_linearity=jax.nn.softplus, +# residual_mlp_1=True, +# residual_mlp_2=True, +# layer_normalization_1=False, +# layer_normalization_2=False, +# activation_fn=jax.nn.softplus, +# behave_like_identity_fn_at_init=False +# ) +# +# ev = jax.ops.segment_sum( +# geometry_embed_output.get('ylm_ij'), +# segment_ids=idx_i, +# num_segments=len(x) +# ) +# +# ev_translated = jax.ops.segment_sum( +# geometry_embed_output_translated.get('ylm_ij'), +# segment_ids=idx_i, +# num_segments=len(x) +# ) +# +# so3k_params = so3krates_layer.init( +# jax.random.PRNGKey(0), +# x=x, +# ev=ev, +# rbf_ij=geometry_embed_output.get('rbf_ij'), +# ylm_ij=geometry_embed_output.get('ylm_ij'), +# cut=geometry_embed_output.get('cut'), +# idx_i=idx_i, +# idx_j=idx_j, +# ) +# +# output = so3krates_layer.apply( +# so3k_params, +# x=x, +# ev=ev, +# rbf_ij=geometry_embed_output.get('rbf_ij'), +# ylm_ij=geometry_embed_output.get('ylm_ij'), +# cut=geometry_embed_output.get('cut'), +# idx_i=idx_i, +# idx_j=idx_j, +# ) +# +# output_translated = so3krates_layer.apply( +# so3k_params, +# x=x, +# ev=ev_translated, +# rbf_ij=geometry_embed_output_translated.get('rbf_ij'), +# ylm_ij=geometry_embed_output_translated.get('ylm_ij'), +# cut=geometry_embed_output_translated.get('cut'), +# idx_i=idx_i, +# idx_j=idx_j, +# ) +# +# npt.assert_allclose(output_translated.get('x'), output.get('x')) +# 
npt.assert_allclose(output_translated.get('ev'), output.get('ev')) +# +# +# def test_rotation_equivariance(): +# from mlff.geometric import get_rotation_matrix +# +# x = jnp.ones((3, 64)) +# positions = jnp.array([ +# [0., 0., 0.], +# [0., -2., 0.], +# [1., 0.5, 0.], +# ]) +# idx_i = jnp.array([0, 0, 1, 2]) +# idx_j = jnp.array([1, 2, 0, 0]) +# +# geometry_embed = GeometryEmbedSparse(degrees=[1, 2], +# radial_basis_fn='bernstein', +# num_radial_basis_fn=16, +# cutoff_fn='exponential', +# cutoff=2.5, +# input_convention='positions', +# prop_keys=None) +# +# geometry_embed_inputs = dict( +# positions=positions, +# idx_i=idx_i, +# idx_j=idx_j, +# cell=None, +# cell_offset=None +# ) +# rot = get_rotation_matrix(euler_axes='xyz', angles=[87, 14, 156], degrees=True) +# geometry_embed_inputs_rotated = dict( +# positions=positions@rot, +# idx_i=idx_i, +# idx_j=idx_j, +# cell=None, +# cell_offset=None +# ) +# +# params = geometry_embed.init( +# jax.random.PRNGKey(0), +# geometry_embed_inputs +# ) +# +# geometry_embed_output = geometry_embed.apply( +# params, +# geometry_embed_inputs +# ) +# +# geometry_embed_output_rotated = geometry_embed.apply( +# params, +# geometry_embed_inputs_rotated +# ) +# +# so3krates_layer = SO3kratesLayerSparse( +# degrees=[1, 2], +# use_spherical_filter=True, +# num_heads=2, +# num_features_head=16, +# qk_non_linearity=jax.nn.softplus, +# residual_mlp_1=True, +# residual_mlp_2=True, +# layer_normalization_1=False, +# layer_normalization_2=False, +# activation_fn=jax.nn.softplus, +# behave_like_identity_fn_at_init=False +# ) +# +# ev = jax.ops.segment_sum( +# geometry_embed_output.get('ylm_ij'), +# segment_ids=idx_i, +# num_segments=len(x) +# ) +# +# ev_rotated = jax.ops.segment_sum( +# geometry_embed_output_rotated.get('ylm_ij'), +# segment_ids=idx_i, +# num_segments=len(x) +# ) +# +# so3k_params = so3krates_layer.init( +# jax.random.PRNGKey(0), +# x=x, +# ev=ev, +# rbf_ij=geometry_embed_output.get('rbf_ij'), +# ylm_ij=geometry_embed_output.get('ylm_ij'), +# cut=geometry_embed_output.get('cut'), +# idx_i=idx_i, +# idx_j=idx_j, +# ) +# +# output = so3krates_layer.apply( +# so3k_params, +# x=x, +# ev=ev, +# rbf_ij=geometry_embed_output.get('rbf_ij'), +# ylm_ij=geometry_embed_output.get('ylm_ij'), +# cut=geometry_embed_output.get('cut'), +# idx_i=idx_i, +# idx_j=idx_j, +# ) +# +# output_rotated = so3krates_layer.apply( +# so3k_params, +# x=x, +# ev=ev_rotated, +# rbf_ij=geometry_embed_output_rotated.get('rbf_ij'), +# ylm_ij=geometry_embed_output_rotated.get('ylm_ij'), +# cut=geometry_embed_output_rotated.get('cut'), +# idx_i=idx_i, +# idx_j=idx_j, +# ) +# +# npt.assert_allclose(output_rotated.get('x'), output.get('x'), atol=1e-5) +# with npt.assert_raises(AssertionError): +# npt.assert_allclose(output_rotated.get('ev'), output.get('ev'), atol=1e-5) +# +# P = jnp.array([2, 0, 1]) +# npt.assert_allclose( +# output_rotated.get('ev')[:, :3][:, P]@rot.T, +# output.get('ev')[:, :3][:, P], +# atol=1e-5 +# )
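Usage note (outside the diff): with this patch installed, setup.py registers a `train_itp_net` console script that wraps `mlff.CLI.run_training_itp_net:train_itp_net`, which expects a `.yaml` or `.json` config. Below is a minimal sketch of driving it programmatically; the dataset path `data/ethanol.npz` is a hypothetical placeholder, and `mlff/config/config_itp_net.yaml` (added above) is used as the template.

```python
# Sketch: derive a run config from the shipped ITPNet template and launch training.
# Assumes the mlff package from this diff is installed, so `train_itp_net` is on PATH.
import pathlib
import subprocess

import yaml

template = pathlib.Path('mlff/config/config_itp_net.yaml')
cfg = yaml.safe_load(template.read_text())

cfg['data']['filepath'] = 'data/ethanol.npz'  # hypothetical dataset path
cfg['training']['num_train'] = 800
cfg['training']['num_valid'] = 100

out = pathlib.Path('my_itp_config.yaml')
out.write_text(yaml.safe_dump(cfg))

# Equivalent to running: train_itp_net --config my_itp_config.yaml
subprocess.run(['train_itp_net', '--config', str(out)], check=True)
```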