Skip to content

Commit

Permalink
Merge pull request #285 from apax-hub/feature_model
Browse files Browse the repository at this point in the history
refactored distance computation, added feature model
  • Loading branch information
M-R-Schaefer authored Jul 9, 2024
2 parents 6c8a91a + 67fceae commit 758c7f2
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 57 deletions.
69 changes: 69 additions & 0 deletions apax/layers/distances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import jax.numpy as jnp
import numpy as np
from jax import vmap

from apax.utils.jax_md_reduced import partition, space


def canonicalize_neighbors(neighbor):
return neighbor.idx if isinstance(neighbor, partition.NeighborList) else neighbor


def disp_fn(ri, rj, perturbation, box):
dR = space.pairwise_displacement(ri, rj)
dR = space.transform(box, dR)

if perturbation is not None:
dR = dR + space.raw_transform(perturbation, dR)
# https://github.com/mir-group/nequip/blob/c56f48fcc9b4018a84e1ed28f762fadd5bc763f1/nequip/nn/_grad_output.py#L267
# https://github.com/sirmarcel/glp/blob/main/glp/calculators/utils.py
# other codes do R = R + strain, not dR
# can be implemented for efficiency
return dR


def get_disp_fn(displacement):
def disp_fn(ri, rj, perturbation, box):
return displacement(ri, rj, perturbation, box=box)

return disp_fn


def make_distance_fn(init_box, inference_disp_fn=None):
"""Model which post processes the output of an atomistic model and
adds empirical energy terms.
"""

if np.all(init_box < 1e-6):
# gas phase training and predicting
displacement_fn = space.free()[0]
displacement = space.map_bond(displacement_fn)
elif inference_disp_fn is None:
# for training on periodic systems
displacement = vmap(disp_fn, (0, 0, None, None), 0)
else:
mappable_displacement_fn = get_disp_fn(inference_disp_fn)
displacement = vmap(mappable_displacement_fn, (0, 0, None, None), 0)

def compute_distances(R, neighbor, box, offsets, perturbation=None):
# Distances
idx = canonicalize_neighbors(neighbor)
idx_i, idx_j = idx[0], idx[1]

# R shape n_atoms x 3
R = R.astype(jnp.float64)
Ri = R[idx_i]
Rj = R[idx_j]

# dr_vec shape: neighbors x 3
if np.all(init_box < 1e-6):
# reverse conventnion to match TF
# distance vector for gas phase training and predicting
dr_vec = displacement(Rj, Ri)
else:
# distance vector for training on periodic systems
dr_vec = displacement(Rj, Ri, perturbation, box)
dr_vec += offsets
return dr_vec, idx

return compute_distances
2 changes: 1 addition & 1 deletion apax/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def build_energy_model(
corrections.append(repulsion)

model = EnergyModel(
atomistic_model,
atomistic_model=atomistic_model,
corrections=corrections,
init_box=init_box,
inference_disp_fn=inference_disp_fn,
Expand Down
97 changes: 42 additions & 55 deletions apax/model/gmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array, vmap
from jax import Array

from apax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.layers.distances import make_distance_fn
from apax.layers.empirical import EmpiricalEnergyTerm
from apax.layers.masking import mask_by_atom
from apax.layers.properties import stress_times_vol
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.utils.jax_md_reduced import partition, space
from apax.utils.jax_md_reduced import partition
from apax.utils.math import fp64_sum

DisplacementFn = Callable[[Array, Array], Array]
Expand All @@ -23,30 +23,6 @@
log = logging.getLogger(__name__)


def canonicalize_neighbors(neighbor):
return neighbor.idx if isinstance(neighbor, partition.NeighborList) else neighbor


def disp_fn(ri, rj, perturbation, box):
dR = space.pairwise_displacement(ri, rj)
dR = space.transform(box, dR)

if perturbation is not None:
dR = dR + space.raw_transform(perturbation, dR)
# https://github.com/mir-group/nequip/blob/c56f48fcc9b4018a84e1ed28f762fadd5bc763f1/nequip/nn/_grad_output.py#L267
# https://github.com/sirmarcel/glp/blob/main/glp/calculators/utils.py
# other codes do R = R + strain, not dR
# can be implemented for efficiency
return dR


def get_disp_fn(displacement):
def disp_fn(ri, rj, perturbation, box):
return displacement(ri, rj, perturbation, box=box)

return disp_fn


class AtomisticModel(nn.Module):
"""Most basic prediction model.
Allesmbles descriptor, readout (NNs) and output scale-shifting.
Expand All @@ -72,6 +48,37 @@ def __call__(
return output


class FeatureModel(nn.Module):
"""Model wrapps some submodel (e.g. a descriptor) to supply distance computation."""

feature_model: nn.Module = GaussianMomentDescriptor()
init_box: np.array = field(default_factory=lambda: np.array([0.0, 0.0, 0.0]))
inference_disp_fn: Any = None

def setup(self):
self.compute_distances = make_distance_fn(self.init_box, self.inference_disp_fn)

def __call__(
self,
R: Array,
Z: Array,
neighbor: Union[partition.NeighborList, Array],
box,
offsets,
perturbation=None,
):
dr_vec, idx = self.compute_distances(
R,
neighbor,
box,
offsets,
perturbation,
)

features = self.feature_model(dr_vec, Z, idx)
return features


class EnergyModel(nn.Module):
"""Model which post processes the output of an atomistic model and
adds empirical energy terms.
Expand All @@ -83,16 +90,7 @@ class EnergyModel(nn.Module):
inference_disp_fn: Any = None

def setup(self):
if np.all(self.init_box < 1e-6):
# gas phase training and predicting
displacement_fn = space.free()[0]
self.displacement = space.map_bond(displacement_fn)
elif self.inference_disp_fn is None:
# for training on periodic systems
self.displacement = vmap(disp_fn, (0, 0, None, None), 0)
else:
mappable_displacement_fn = get_disp_fn(self.inference_disp_fn)
self.displacement = vmap(mappable_displacement_fn, (0, 0, None, None), 0)
self.compute_distances = make_distance_fn(self.init_box, self.inference_disp_fn)

def __call__(
self,
Expand All @@ -103,24 +101,13 @@ def __call__(
offsets,
perturbation=None,
):
# Distances
idx = canonicalize_neighbors(neighbor)
idx_i, idx_j = idx[0], idx[1]

# R shape n_atoms x 3
R = R.astype(jnp.float64)
Ri = R[idx_i]
Rj = R[idx_j]

# dr_vec shape: neighbors x 3
if np.all(self.init_box < 1e-6):
# reverse conventnion to match TF
# distance vector for gas phase training and predicting
dr_vec = self.displacement(Rj, Ri)
else:
# distance vector for training on periodic systems
dr_vec = self.displacement(Rj, Ri, perturbation, box)
dr_vec += offsets
dr_vec, idx = self.compute_distances(
R,
neighbor,
box,
offsets,
perturbation,
)

# Model Core
# shape Natoms
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/data/test_input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax import vmap

from apax.data.preprocessing import compute_nl
from apax.model.gmnn import disp_fn
from apax.layers.distances import disp_fn
from apax.utils.convert import atoms_to_inputs, atoms_to_labels
from apax.utils.data import split_atoms, split_idxs
from apax.utils.random import seed_py_np_tf
Expand Down

0 comments on commit 758c7f2

Please sign in to comment.